zarr_datafusion/optimizer/
count_optimization.rs

//! Custom optimizer rule that replaces count aggregates with constants taken from table statistics.
//!
//! This rule transforms:
//!   `SELECT count(*) FROM table` → `SELECT 700`
//!   `SELECT count(col) FROM table` → `SELECT 700` (when null_count = 0; otherwise
//!   `num_rows - null_count` when the null count is exact)
//!   `SELECT count(*), count(col1), count(col2) FROM table` → `SELECT 700, 700, 700`
//!
//! The optimization is only applied when:
//! 1. All aggregate expressions are count() functions (no sum, avg, etc.)
//! 2. There is no GROUP BY clause
//! 3. There are no WHERE filters
//! 4. Statistics provide an exact num_rows
//! 5. For count(col), the column has an exact null_count
14
15use datafusion::common::stats::Precision;
16use datafusion::common::tree_node::Transformed;
17use datafusion::common::Result;
18use datafusion::datasource::source_as_provider;
19use datafusion::logical_expr::expr::AggregateFunction;
20use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableScan};
21use datafusion::optimizer::optimizer::ApplyOrder;
22use datafusion::optimizer::OptimizerRule;
23use datafusion::prelude::lit;
24use std::sync::Arc;
25use tracing::{debug, info, trace, warn};
26
/// Logical optimizer rule that swaps `count()` aggregates for literal values
/// derived from table statistics.
///
/// The type carries no state; all behaviour lives in its `OptimizerRule`
/// implementation.
#[derive(Default, Debug)]
pub struct CountStatisticsRule;
30
31impl CountStatisticsRule {
32    pub fn new() -> Self {
33        Self
34    }
35}
36
37impl OptimizerRule for CountStatisticsRule {
38    fn name(&self) -> &str {
39        "count_statistics"
40    }
41
42    fn apply_order(&self) -> Option<ApplyOrder> {
43        Some(ApplyOrder::BottomUp)
44    }
45
46    fn supports_rewrite(&self) -> bool {
47        true
48    }
49
50    fn rewrite(
51        &self,
52        plan: LogicalPlan,
53        _config: &dyn datafusion::optimizer::OptimizerConfig,
54    ) -> Result<Transformed<LogicalPlan>> {
55        // Only optimize Aggregate nodes
56        let LogicalPlan::Aggregate(aggregate) = &plan else {
57            trace!("Skipping non-Aggregate node");
58            return Ok(Transformed::no(plan));
59        };
60
61        debug!(
62            group_by_count = aggregate.group_expr.len(),
63            aggr_count = aggregate.aggr_expr.len(),
64            "Evaluating Aggregate node for count optimization"
65        );
66
67        // Must have no GROUP BY (simple aggregate)
68        if !aggregate.group_expr.is_empty() {
69            debug!(
70                group_by_count = aggregate.group_expr.len(),
71                "Skipping: has GROUP BY clause"
72            );
73            return Ok(Transformed::no(plan));
74        }
75
76        // Must have at least one aggregate expression
77        if aggregate.aggr_expr.is_empty() {
78            debug!("Skipping: no aggregate expressions");
79            return Ok(Transformed::no(plan));
80        }
81
82        // Input must be a simple TableScan with no filters
83        let input = aggregate.input.as_ref();
84        let table_scan = match unwrap_to_table_scan(input) {
85            Some(scan) => scan,
86            None => {
87                debug!("Skipping: could not find TableScan in input");
88                return Ok(Transformed::no(plan));
89            }
90        };
91
92        let table_name = table_scan.table_name.to_string();
93        debug!(table = %table_name, "Found TableScan");
94
95        // Must have no filters
96        if !table_scan.filters.is_empty() {
97            debug!(
98                table = %table_name,
99                filter_count = table_scan.filters.len(),
100                "Skipping: query has WHERE filters"
101            );
102            return Ok(Transformed::no(plan));
103        }
104
105        // Get TableProvider from TableSource to access statistics
106        let provider = match source_as_provider(&table_scan.source) {
107            Ok(p) => p,
108            Err(e) => {
109                warn!(
110                    table = %table_name,
111                    error = %e,
112                    "Could not get TableProvider from TableSource"
113                );
114                return Ok(Transformed::no(plan));
115            }
116        };
117
118        // Get statistics from the table
119        let statistics = match provider.statistics() {
120            Some(stats) => stats,
121            None => {
122                debug!(
123                    table = %table_name,
124                    "Skipping: table does not provide statistics"
125                );
126                return Ok(Transformed::no(plan));
127            }
128        };
129
130        // num_rows must be exact
131        let num_rows: usize = match statistics.num_rows {
132            Precision::Exact(n) => {
133                debug!(table = %table_name, num_rows = n, "Got exact row count from statistics");
134                n
135            }
136            Precision::Inexact(n) => {
137                debug!(
138                    table = %table_name,
139                    approx_rows = n,
140                    "Skipping: row count is inexact"
141                );
142                return Ok(Transformed::no(plan));
143            }
144            Precision::Absent => {
145                debug!(table = %table_name, "Skipping: row count not available");
146                return Ok(Transformed::no(plan));
147            }
148        };
149
150        let schema = provider.schema();
151
152        // Process all aggregate expressions - all must be count() functions
153        let mut count_values: Vec<(String, i64)> = Vec::new();
154
155        for aggr_expr in &aggregate.aggr_expr {
156            match extract_count_info(aggr_expr) {
157                Some((is_count_star, column_name)) => {
158                    // Get the alias name for this expression
159                    let alias_name = get_expr_alias(aggr_expr, &aggregate.schema);
160
161                    let count_value = if is_count_star {
162                        debug!(alias = %alias_name, "Found count(*)");
163                        num_rows as i64
164                    } else if let Some(ref col_name) = column_name {
165                        // Find column index and get null count
166                        let col_idx = schema.fields().iter().position(|f| f.name() == col_name);
167
168                        if let Some(idx) = col_idx {
169                            let col_stats = &statistics.column_statistics;
170                            if idx < col_stats.len() {
171                                match col_stats[idx].null_count {
172                                    Precision::Exact(0) => {
173                                        debug!(
174                                            alias = %alias_name,
175                                            column = %col_name,
176                                            "Column has no nulls, count = num_rows"
177                                        );
178                                        num_rows as i64
179                                    }
180                                    Precision::Exact(nulls) => {
181                                        let count = num_rows.saturating_sub(nulls) as i64;
182                                        debug!(
183                                            alias = %alias_name,
184                                            column = %col_name,
185                                            null_count = nulls,
186                                            count,
187                                            "Column has nulls, count = num_rows - null_count"
188                                        );
189                                        count
190                                    }
191                                    Precision::Inexact(nulls) => {
192                                        debug!(
193                                            column = %col_name,
194                                            approx_nulls = nulls,
195                                            "Skipping: null count is inexact"
196                                        );
197                                        return Ok(Transformed::no(plan));
198                                    }
199                                    Precision::Absent => {
200                                        debug!(
201                                            column = %col_name,
202                                            "Skipping: null count not available"
203                                        );
204                                        return Ok(Transformed::no(plan));
205                                    }
206                                }
207                            } else {
208                                debug!(
209                                    column = %col_name,
210                                    "Skipping: column statistics not available"
211                                );
212                                return Ok(Transformed::no(plan));
213                            }
214                        } else {
215                            debug!(column = %col_name, "Skipping: column not found in schema");
216                            return Ok(Transformed::no(plan));
217                        }
218                    } else {
219                        debug!("Skipping: column name not available");
220                        return Ok(Transformed::no(plan));
221                    };
222
223                    count_values.push((alias_name, count_value));
224                }
225                None => {
226                    // Not a count function - can't optimize this aggregate
227                    debug!("Skipping: found non-count aggregate function");
228                    return Ok(Transformed::no(plan));
229                }
230            }
231        }
232
233        info!(
234            table = %table_name,
235            count_expressions = count_values.len(),
236            values = ?count_values,
237            "Count optimization applied - replacing with constants"
238        );
239
240        create_multi_count_plan(&count_values)
241    }
242}
243
244/// Extract count function info from an expression
245/// Returns Some((is_count_star, column_name)) if it's a count function, None otherwise
246fn extract_count_info(expr: &Expr) -> Option<(bool, Option<String>)> {
247    let inner = unwrap_alias(expr);
248    match inner {
249        Expr::AggregateFunction(AggregateFunction { func, params }) => {
250            if func.name() != "count" {
251                trace!(func_name = %func.name(), "Not a count function");
252                return None;
253            }
254            // Check if count(*) or count(column)
255            match params.args.first() {
256                Some(Expr::Literal(_, _)) => Some((true, None)), // count(1) or count(null)
257                Some(Expr::Column(col)) => Some((false, Some(col.name.clone()))),
258                None => Some((true, None)), // count() with no args = count(*)
259                _ => {
260                    debug!("Unsupported count argument type");
261                    None
262                }
263            }
264        }
265        _ => {
266            trace!("Not an AggregateFunction expression");
267            None
268        }
269    }
270}
271
272/// Get the alias name for an expression
273fn get_expr_alias(expr: &Expr, _schema: &Arc<datafusion::common::DFSchema>) -> String {
274    // Try to get alias from the expression itself
275    if let Expr::Alias(alias) = expr {
276        return alias.name.clone();
277    }
278
279    // Use the expression's schema name (e.g., "count(*)" or "count(col)")
280    expr.schema_name().to_string()
281}
282
283/// Unwrap Alias expressions to get the inner expression
284fn unwrap_alias(expr: &Expr) -> &Expr {
285    match expr {
286        Expr::Alias(alias) => unwrap_alias(&alias.expr),
287        other => other,
288    }
289}
290
291/// Unwrap projections to find the underlying TableScan
292/// Returns None if there's a Filter (count would change after filtering)
293fn unwrap_to_table_scan(plan: &LogicalPlan) -> Option<&TableScan> {
294    match plan {
295        LogicalPlan::TableScan(scan) => Some(scan),
296        LogicalPlan::Projection(Projection { input, .. }) => unwrap_to_table_scan(input),
297        LogicalPlan::SubqueryAlias(alias) => unwrap_to_table_scan(&alias.input),
298        LogicalPlan::Filter(_) => {
299            // Filter changes the count - can't optimize
300            debug!("Found Filter node - count would change after filtering");
301            None
302        }
303        other => {
304            debug!(node_type = %other.display(), "Unsupported node type in input");
305            None
306        }
307    }
308}
309
310/// Create a new plan that returns multiple count values as constants
311fn create_multi_count_plan(count_values: &[(String, i64)]) -> Result<Transformed<LogicalPlan>> {
312    // Create literal expressions for each count value with their aliases
313    let exprs: Vec<Expr> = count_values
314        .iter()
315        .map(|(alias, value)| lit(*value).alias(alias))
316        .collect();
317
318    // Create an empty relation (no input needed)
319    let empty = LogicalPlan::EmptyRelation(EmptyRelation {
320        produce_one_row: true,
321        schema: Arc::new(arrow::datatypes::Schema::empty().try_into()?),
322    });
323
324    // Create projection with all the constants
325    let projection = LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(empty))?);
326
327    Ok(Transformed::yes(projection))
328}