zarr_datafusion/optimizer/
count_optimization.rs1use datafusion::common::stats::Precision;
16use datafusion::common::tree_node::Transformed;
17use datafusion::common::Result;
18use datafusion::datasource::source_as_provider;
19use datafusion::logical_expr::expr::AggregateFunction;
20use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableScan};
21use datafusion::optimizer::optimizer::ApplyOrder;
22use datafusion::optimizer::OptimizerRule;
23use datafusion::prelude::lit;
24use std::sync::Arc;
25use tracing::{debug, info, trace, warn};
26
/// Logical optimizer rule that answers ungrouped `COUNT` aggregates from a table's exact
/// statistics, replacing the aggregate with a one-row projection of constants.
#[derive(Debug, Default)]
pub struct CountStatisticsRule;

impl CountStatisticsRule {
    /// Creates the rule; identical to [`CountStatisticsRule::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
36
37impl OptimizerRule for CountStatisticsRule {
38 fn name(&self) -> &str {
39 "count_statistics"
40 }
41
42 fn apply_order(&self) -> Option<ApplyOrder> {
43 Some(ApplyOrder::BottomUp)
44 }
45
46 fn supports_rewrite(&self) -> bool {
47 true
48 }
49
50 fn rewrite(
51 &self,
52 plan: LogicalPlan,
53 _config: &dyn datafusion::optimizer::OptimizerConfig,
54 ) -> Result<Transformed<LogicalPlan>> {
55 let LogicalPlan::Aggregate(aggregate) = &plan else {
57 trace!("Skipping non-Aggregate node");
58 return Ok(Transformed::no(plan));
59 };
60
61 debug!(
62 group_by_count = aggregate.group_expr.len(),
63 aggr_count = aggregate.aggr_expr.len(),
64 "Evaluating Aggregate node for count optimization"
65 );
66
67 if !aggregate.group_expr.is_empty() {
69 debug!(
70 group_by_count = aggregate.group_expr.len(),
71 "Skipping: has GROUP BY clause"
72 );
73 return Ok(Transformed::no(plan));
74 }
75
76 if aggregate.aggr_expr.is_empty() {
78 debug!("Skipping: no aggregate expressions");
79 return Ok(Transformed::no(plan));
80 }
81
82 let input = aggregate.input.as_ref();
84 let table_scan = match unwrap_to_table_scan(input) {
85 Some(scan) => scan,
86 None => {
87 debug!("Skipping: could not find TableScan in input");
88 return Ok(Transformed::no(plan));
89 }
90 };
91
92 let table_name = table_scan.table_name.to_string();
93 debug!(table = %table_name, "Found TableScan");
94
95 if !table_scan.filters.is_empty() {
97 debug!(
98 table = %table_name,
99 filter_count = table_scan.filters.len(),
100 "Skipping: query has WHERE filters"
101 );
102 return Ok(Transformed::no(plan));
103 }
104
105 let provider = match source_as_provider(&table_scan.source) {
107 Ok(p) => p,
108 Err(e) => {
109 warn!(
110 table = %table_name,
111 error = %e,
112 "Could not get TableProvider from TableSource"
113 );
114 return Ok(Transformed::no(plan));
115 }
116 };
117
118 let statistics = match provider.statistics() {
120 Some(stats) => stats,
121 None => {
122 debug!(
123 table = %table_name,
124 "Skipping: table does not provide statistics"
125 );
126 return Ok(Transformed::no(plan));
127 }
128 };
129
130 let num_rows: usize = match statistics.num_rows {
132 Precision::Exact(n) => {
133 debug!(table = %table_name, num_rows = n, "Got exact row count from statistics");
134 n
135 }
136 Precision::Inexact(n) => {
137 debug!(
138 table = %table_name,
139 approx_rows = n,
140 "Skipping: row count is inexact"
141 );
142 return Ok(Transformed::no(plan));
143 }
144 Precision::Absent => {
145 debug!(table = %table_name, "Skipping: row count not available");
146 return Ok(Transformed::no(plan));
147 }
148 };
149
150 let schema = provider.schema();
151
152 let mut count_values: Vec<(String, i64)> = Vec::new();
154
155 for aggr_expr in &aggregate.aggr_expr {
156 match extract_count_info(aggr_expr) {
157 Some((is_count_star, column_name)) => {
158 let alias_name = get_expr_alias(aggr_expr, &aggregate.schema);
160
161 let count_value = if is_count_star {
162 debug!(alias = %alias_name, "Found count(*)");
163 num_rows as i64
164 } else if let Some(ref col_name) = column_name {
165 let col_idx = schema.fields().iter().position(|f| f.name() == col_name);
167
168 if let Some(idx) = col_idx {
169 let col_stats = &statistics.column_statistics;
170 if idx < col_stats.len() {
171 match col_stats[idx].null_count {
172 Precision::Exact(0) => {
173 debug!(
174 alias = %alias_name,
175 column = %col_name,
176 "Column has no nulls, count = num_rows"
177 );
178 num_rows as i64
179 }
180 Precision::Exact(nulls) => {
181 let count = num_rows.saturating_sub(nulls) as i64;
182 debug!(
183 alias = %alias_name,
184 column = %col_name,
185 null_count = nulls,
186 count,
187 "Column has nulls, count = num_rows - null_count"
188 );
189 count
190 }
191 Precision::Inexact(nulls) => {
192 debug!(
193 column = %col_name,
194 approx_nulls = nulls,
195 "Skipping: null count is inexact"
196 );
197 return Ok(Transformed::no(plan));
198 }
199 Precision::Absent => {
200 debug!(
201 column = %col_name,
202 "Skipping: null count not available"
203 );
204 return Ok(Transformed::no(plan));
205 }
206 }
207 } else {
208 debug!(
209 column = %col_name,
210 "Skipping: column statistics not available"
211 );
212 return Ok(Transformed::no(plan));
213 }
214 } else {
215 debug!(column = %col_name, "Skipping: column not found in schema");
216 return Ok(Transformed::no(plan));
217 }
218 } else {
219 debug!("Skipping: column name not available");
220 return Ok(Transformed::no(plan));
221 };
222
223 count_values.push((alias_name, count_value));
224 }
225 None => {
226 debug!("Skipping: found non-count aggregate function");
228 return Ok(Transformed::no(plan));
229 }
230 }
231 }
232
233 info!(
234 table = %table_name,
235 count_expressions = count_values.len(),
236 values = ?count_values,
237 "Count optimization applied - replacing with constants"
238 );
239
240 create_multi_count_plan(&count_values)
241 }
242}
243
244fn extract_count_info(expr: &Expr) -> Option<(bool, Option<String>)> {
247 let inner = unwrap_alias(expr);
248 match inner {
249 Expr::AggregateFunction(AggregateFunction { func, params }) => {
250 if func.name() != "count" {
251 trace!(func_name = %func.name(), "Not a count function");
252 return None;
253 }
254 match params.args.first() {
256 Some(Expr::Literal(_, _)) => Some((true, None)), Some(Expr::Column(col)) => Some((false, Some(col.name.clone()))),
258 None => Some((true, None)), _ => {
260 debug!("Unsupported count argument type");
261 None
262 }
263 }
264 }
265 _ => {
266 trace!("Not an AggregateFunction expression");
267 None
268 }
269 }
270}
271
272fn get_expr_alias(expr: &Expr, _schema: &Arc<datafusion::common::DFSchema>) -> String {
274 if let Expr::Alias(alias) = expr {
276 return alias.name.clone();
277 }
278
279 expr.schema_name().to_string()
281}
282
283fn unwrap_alias(expr: &Expr) -> &Expr {
285 match expr {
286 Expr::Alias(alias) => unwrap_alias(&alias.expr),
287 other => other,
288 }
289}
290
291fn unwrap_to_table_scan(plan: &LogicalPlan) -> Option<&TableScan> {
294 match plan {
295 LogicalPlan::TableScan(scan) => Some(scan),
296 LogicalPlan::Projection(Projection { input, .. }) => unwrap_to_table_scan(input),
297 LogicalPlan::SubqueryAlias(alias) => unwrap_to_table_scan(&alias.input),
298 LogicalPlan::Filter(_) => {
299 debug!("Found Filter node - count would change after filtering");
301 None
302 }
303 other => {
304 debug!(node_type = %other.display(), "Unsupported node type in input");
305 None
306 }
307 }
308}
309
310fn create_multi_count_plan(count_values: &[(String, i64)]) -> Result<Transformed<LogicalPlan>> {
312 let exprs: Vec<Expr> = count_values
314 .iter()
315 .map(|(alias, value)| lit(*value).alias(alias))
316 .collect();
317
318 let empty = LogicalPlan::EmptyRelation(EmptyRelation {
320 produce_one_row: true,
321 schema: Arc::new(arrow::datatypes::Schema::empty().try_into()?),
322 });
323
324 let projection = LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(empty))?);
326
327 Ok(Transformed::yes(projection))
328}