1use crate::query_plan::pipeline::ASTTransformer;
41use crate::sql::parser::ast::{
42 CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, TableSource,
43};
44use anyhow::Result;
45use std::collections::{HashMap, HashSet};
46use tracing::debug;
47
/// Prefix of the synthetic aliases given to ORDER BY expressions that are
/// promoted into the SELECT list as hidden items.
///
/// The full alias is this prefix followed by a counter, e.g.
/// `__hidden_orderby_1`.
// NOTE(review): presumably later stages recognise this prefix to drop the
// hidden columns from the final output — confirm against the consumers.
pub const HIDDEN_ORDERBY_PREFIX: &str = "__hidden_orderby_";
52
/// AST transformer that rewrites ORDER BY entries so they refer to SELECT
/// items by alias.
///
/// It keeps two independent counters so that every synthetic alias it hands
/// out during its lifetime is unique.
#[derive(Debug)]
pub struct OrderByAliasTransformer {
    /// Number of `__orderby_agg_N` aliases generated so far.
    alias_counter: usize,
    /// Number of hidden ORDER BY aliases generated so far.
    hidden_counter: usize,
}
60
61impl OrderByAliasTransformer {
62 pub fn new() -> Self {
63 Self {
64 alias_counter: 0,
65 hidden_counter: 0,
66 }
67 }
68
69 fn is_aggregate_function(expr: &SqlExpression) -> bool {
71 matches!(
72 expr,
73 SqlExpression::FunctionCall { name, .. }
74 if matches!(
75 name.to_uppercase().as_str(),
76 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
77 )
78 )
79 }
80
81 fn generate_alias(&mut self) -> String {
83 self.alias_counter += 1;
84 format!("__orderby_agg_{}", self.alias_counter)
85 }
86
87 fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
93 match expr {
94 SqlExpression::FunctionCall { name, args, .. } => {
95 let args_str = args
96 .iter()
97 .map(|arg| match arg {
98 SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
99 SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
101 SqlExpression::StringLiteral(s) => format!("'{}'", s).to_uppercase(),
102 SqlExpression::NumberLiteral(n) => n.to_uppercase(),
103 _ => format!("{:?}", arg).to_uppercase(), })
105 .collect::<Vec<_>>()
106 .join(", ");
107 format!("{}({})", name.to_uppercase(), args_str)
108 }
109 _ => String::new(),
110 }
111 }
112
113 fn build_aggregate_map(
116 &mut self,
117 select_items: &mut Vec<SelectItem>,
118 ) -> HashMap<String, String> {
119 let mut aggregate_map = HashMap::new();
120
121 for item in select_items.iter_mut() {
122 if let SelectItem::Expression { expr, alias, .. } = item {
123 if Self::is_aggregate_function(expr) {
124 let normalized = Self::normalize_aggregate_expr(expr);
125
126 if alias.is_empty() {
128 *alias = self.generate_alias();
129 debug!(
130 "Generated alias '{}' for aggregate in ORDER BY: {}",
131 alias, normalized
132 );
133 }
134
135 debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
136 aggregate_map.insert(normalized, alias.clone());
137 }
138 }
139 }
140
141 aggregate_map
142 }
143
144 fn expression_to_string(expr: &SqlExpression) -> String {
147 match expr {
148 SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
149 SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
151 SqlExpression::StringLiteral(s) => format!("'{}'", s),
152 SqlExpression::FunctionCall { name, args, .. } => {
153 let args_str = args
154 .iter()
155 .map(|arg| Self::expression_to_string(arg))
156 .collect::<Vec<_>>()
157 .join(", ");
158 format!("{}({})", name.to_uppercase(), args_str)
159 }
160 _ => "expr".to_string(), }
162 }
163
164 fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
167 let upper = column_name.to_uppercase();
170
171 if (upper.starts_with("COUNT(") && upper.ends_with(')'))
172 || (upper.starts_with("SUM(") && upper.ends_with(')'))
173 || (upper.starts_with("AVG(") && upper.ends_with(')'))
174 || (upper.starts_with("MIN(") && upper.ends_with(')'))
175 || (upper.starts_with("MAX(") && upper.ends_with(')'))
176 || (upper.starts_with("COUNT_DISTINCT(") && upper.ends_with(')'))
177 {
178 Some(upper)
180 } else {
181 None
182 }
183 }
184}
185
186impl Default for OrderByAliasTransformer {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192impl ASTTransformer for OrderByAliasTransformer {
193 fn name(&self) -> &str {
194 "OrderByAliasTransformer"
195 }
196
197 fn description(&self) -> &str {
198 "Rewrites ORDER BY aggregate expressions to use SELECT aliases"
199 }
200
201 fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
202 self.transform_statement(stmt)
203 }
204}
205
206impl OrderByAliasTransformer {
207 #[allow(deprecated)]
211 fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
212 for cte in stmt.ctes.iter_mut() {
214 if let CTEType::Standard(ref mut inner) = cte.cte_type {
215 let taken = std::mem::take(inner);
216 *inner = self.transform_statement(taken)?;
217 }
218 }
219
220 if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
222 let taken = std::mem::take(query.as_mut());
223 **query = self.transform_statement(taken)?;
224 }
225
226 if let Some(subq) = stmt.from_subquery.as_mut() {
228 let taken = std::mem::take(subq.as_mut());
229 **subq = self.transform_statement(taken)?;
230 }
231
232 for (_op, rhs) in stmt.set_operations.iter_mut() {
234 let taken = std::mem::take(rhs.as_mut());
235 **rhs = self.transform_statement(taken)?;
236 }
237
238 self.apply_rewrite(&mut stmt);
240
241 Ok(stmt)
242 }
243
244 fn apply_rewrite(&mut self, stmt: &mut SelectStatement) {
246 if stmt.order_by.is_none() {
247 return;
248 }
249
250 let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
252
253 if !aggregate_map.is_empty() {
255 if let Some(order_by) = stmt.order_by.as_mut() {
256 let mut modified = false;
257
258 for order_col in order_by.iter_mut() {
259 let expr_str = Self::expression_to_string(&order_col.expr);
260
261 if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
262 if let Some(alias) = aggregate_map.get(&normalized) {
263 debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
264 order_col.expr =
265 SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
266 modified = true;
267 }
268 }
269 }
270
271 if modified {
272 debug!(
273 "Rewrote ORDER BY to use {} aggregate alias(es)",
274 aggregate_map.len()
275 );
276 }
277 }
278 }
279
280 self.promote_hidden_order_by_columns(stmt);
285 }
286
287 fn promote_hidden_order_by_columns(&mut self, stmt: &mut SelectStatement) {
292 let order_by = match stmt.order_by.as_mut() {
293 Some(o) if !o.is_empty() => o,
294 _ => return,
295 };
296
297 if stmt
299 .select_items
300 .iter()
301 .any(|i| matches!(i, SelectItem::Star { .. } | SelectItem::StarExclude { .. }))
302 {
303 return;
304 }
305
306 let mut visible: HashSet<String> = HashSet::new();
310 for item in stmt.select_items.iter() {
311 match item {
312 SelectItem::Column { column, .. } => {
313 visible.insert(column.name.to_lowercase());
314 }
315 SelectItem::Expression { alias, .. } if !alias.is_empty() => {
316 visible.insert(alias.to_lowercase());
317 }
318 _ => {}
319 }
320 }
321
322 let mut promoted_columns: HashMap<String, String> = HashMap::new();
325 let mut promoted_exprs: HashMap<String, String> = HashMap::new();
326
327 for order_col in order_by.iter_mut() {
328 if let SqlExpression::Column(c) = &order_col.expr {
331 if visible.contains(&c.name.to_lowercase()) {
332 continue;
333 }
334 }
335
336 let expr_to_promote = order_col.expr.clone();
342 let (dedup_key, is_column) = match &expr_to_promote {
343 SqlExpression::Column(c) => (c.name.to_lowercase(), true),
344 other => (format!("{:?}", other), false),
345 };
346
347 let existing_alias = if is_column {
348 promoted_columns.get(&dedup_key).cloned()
349 } else {
350 promoted_exprs.get(&dedup_key).cloned()
351 };
352
353 let hidden_alias = if let Some(alias) = existing_alias {
354 alias
355 } else {
356 self.hidden_counter += 1;
357 let alias = format!("{}{}", HIDDEN_ORDERBY_PREFIX, self.hidden_counter);
358 debug!(
359 "Promoting ORDER BY expression as hidden SELECT item '{}': {:?}",
360 alias, expr_to_promote
361 );
362 stmt.select_items.push(SelectItem::Expression {
363 expr: expr_to_promote,
364 alias: alias.clone(),
365 leading_comments: Vec::new(),
366 trailing_comment: None,
367 });
368 if is_column {
369 promoted_columns.insert(dedup_key, alias.clone());
370 } else {
371 promoted_exprs.insert(dedup_key, alias.clone());
372 }
373 alias
374 };
375
376 order_col.expr = SqlExpression::Column(ColumnRef::unquoted(hidden_alias));
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Builds an unquoted, unprefixed column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call expression.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_extract_aggregate_from_order_column() {
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(sales_amount)"),
            Some("SUM(SALES_AMOUNT)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT(*)"),
            Some("COUNT(*)".to_string())
        );

        // Matching is case-insensitive and the result is upper-cased.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("count_distinct(id)"),
            Some("COUNT_DISTINCT(ID)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("region"),
            None
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("total"),
            None
        );

        // A missing closing parenthesis is not an aggregate call.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(sales_amount"),
            None
        );

        // Non-aggregate functions are ignored.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("UPPER(name)"),
            None
        );
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let expr = call("SUM", vec![column("sales_amount")]);
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "SUM(SALES_AMOUNT)"
        );

        // The function name is upper-cased as well.
        let lower = call("sum", vec![column("sales_amount")]);
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&lower),
            "SUM(SALES_AMOUNT)"
        );

        // Anything that is not a function call normalizes to an empty string.
        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&column("region")),
            ""
        );
    }

    #[test]
    fn test_is_aggregate_function() {
        let sum_expr = call("SUM", vec![]);
        assert!(OrderByAliasTransformer::is_aggregate_function(&sum_expr));

        let lower_count = call("count", vec![]);
        assert!(OrderByAliasTransformer::is_aggregate_function(&lower_count));

        let upper_expr = call("UPPER", vec![]);
        assert!(!OrderByAliasTransformer::is_aggregate_function(&upper_expr));

        assert!(!OrderByAliasTransformer::is_aggregate_function(&column(
            "region"
        )));
    }
}