sql_cli/query_plan/
order_by_alias_transformer.rs1use crate::query_plan::pipeline::ASTTransformer;
41use crate::sql::parser::ast::{
42 CTEType, ColumnRef, SelectItem, SelectStatement, SqlExpression, TableSource,
43};
44use anyhow::Result;
45use std::collections::HashMap;
46use tracing::debug;
47
/// AST transformer that rewrites aggregate calls in `ORDER BY` (for example
/// `ORDER BY SUM(x)`) into column references to the matching SELECT-list alias.
#[derive(Debug)]
pub struct OrderByAliasTransformer {
    /// Number of aliases generated so far; appended to each generated alias so
    /// that every alias handed out by this instance is unique.
    alias_counter: usize,
}
53
54impl OrderByAliasTransformer {
55 pub fn new() -> Self {
56 Self { alias_counter: 0 }
57 }
58
59 fn is_aggregate_function(expr: &SqlExpression) -> bool {
61 matches!(
62 expr,
63 SqlExpression::FunctionCall { name, .. }
64 if matches!(
65 name.to_uppercase().as_str(),
66 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
67 )
68 )
69 }
70
71 fn generate_alias(&mut self) -> String {
73 self.alias_counter += 1;
74 format!("__orderby_agg_{}", self.alias_counter)
75 }
76
77 fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
83 match expr {
84 SqlExpression::FunctionCall { name, args, .. } => {
85 let args_str = args
86 .iter()
87 .map(|arg| match arg {
88 SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
89 SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
91 SqlExpression::StringLiteral(s) => format!("'{}'", s).to_uppercase(),
92 SqlExpression::NumberLiteral(n) => n.to_uppercase(),
93 _ => format!("{:?}", arg).to_uppercase(), })
95 .collect::<Vec<_>>()
96 .join(", ");
97 format!("{}({})", name.to_uppercase(), args_str)
98 }
99 _ => String::new(),
100 }
101 }
102
103 fn build_aggregate_map(
106 &mut self,
107 select_items: &mut Vec<SelectItem>,
108 ) -> HashMap<String, String> {
109 let mut aggregate_map = HashMap::new();
110
111 for item in select_items.iter_mut() {
112 if let SelectItem::Expression { expr, alias, .. } = item {
113 if Self::is_aggregate_function(expr) {
114 let normalized = Self::normalize_aggregate_expr(expr);
115
116 if alias.is_empty() {
118 *alias = self.generate_alias();
119 debug!(
120 "Generated alias '{}' for aggregate in ORDER BY: {}",
121 alias, normalized
122 );
123 }
124
125 debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
126 aggregate_map.insert(normalized, alias.clone());
127 }
128 }
129 }
130
131 aggregate_map
132 }
133
134 fn expression_to_string(expr: &SqlExpression) -> String {
137 match expr {
138 SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
139 SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
141 SqlExpression::StringLiteral(s) => format!("'{}'", s),
142 SqlExpression::FunctionCall { name, args, .. } => {
143 let args_str = args
144 .iter()
145 .map(|arg| Self::expression_to_string(arg))
146 .collect::<Vec<_>>()
147 .join(", ");
148 format!("{}({})", name.to_uppercase(), args_str)
149 }
150 _ => "expr".to_string(), }
152 }
153
154 fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
157 let upper = column_name.to_uppercase();
160
161 if (upper.starts_with("COUNT(") && upper.ends_with(')'))
162 || (upper.starts_with("SUM(") && upper.ends_with(')'))
163 || (upper.starts_with("AVG(") && upper.ends_with(')'))
164 || (upper.starts_with("MIN(") && upper.ends_with(')'))
165 || (upper.starts_with("MAX(") && upper.ends_with(')'))
166 || (upper.starts_with("COUNT_DISTINCT(") && upper.ends_with(')'))
167 {
168 Some(upper)
170 } else {
171 None
172 }
173 }
174}
175
176impl Default for OrderByAliasTransformer {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl ASTTransformer for OrderByAliasTransformer {
183 fn name(&self) -> &str {
184 "OrderByAliasTransformer"
185 }
186
187 fn description(&self) -> &str {
188 "Rewrites ORDER BY aggregate expressions to use SELECT aliases"
189 }
190
191 fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
192 self.transform_statement(stmt)
193 }
194}
195
196impl OrderByAliasTransformer {
197 #[allow(deprecated)]
201 fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
202 for cte in stmt.ctes.iter_mut() {
204 if let CTEType::Standard(ref mut inner) = cte.cte_type {
205 let taken = std::mem::take(inner);
206 *inner = self.transform_statement(taken)?;
207 }
208 }
209
210 if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
212 let taken = std::mem::take(query.as_mut());
213 **query = self.transform_statement(taken)?;
214 }
215
216 if let Some(subq) = stmt.from_subquery.as_mut() {
218 let taken = std::mem::take(subq.as_mut());
219 **subq = self.transform_statement(taken)?;
220 }
221
222 for (_op, rhs) in stmt.set_operations.iter_mut() {
224 let taken = std::mem::take(rhs.as_mut());
225 **rhs = self.transform_statement(taken)?;
226 }
227
228 self.apply_rewrite(&mut stmt);
230
231 Ok(stmt)
232 }
233
234 fn apply_rewrite(&mut self, stmt: &mut SelectStatement) {
236 if stmt.order_by.is_none() {
237 return;
238 }
239
240 let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
242
243 if aggregate_map.is_empty() {
244 return;
245 }
246
247 if let Some(order_by) = stmt.order_by.as_mut() {
249 let mut modified = false;
250
251 for order_col in order_by.iter_mut() {
252 let expr_str = Self::expression_to_string(&order_col.expr);
253
254 if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
255 if let Some(alias) = aggregate_map.get(&normalized) {
256 debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
257 order_col.expr = SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
258 modified = true;
259 }
260 }
261 }
262
263 if modified {
264 debug!(
265 "Rewrote ORDER BY to use {} aggregate alias(es)",
266 aggregate_map.len()
267 );
268 }
269 }
270 }
271}
272
/// Unit tests for the string-level helpers of `OrderByAliasTransformer`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    #[test]
    fn test_extract_aggregate_from_order_column() {
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(sales_amount)"),
            Some("SUM(SALES_AMOUNT)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT(*)"),
            Some("COUNT(*)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("region"),
            None
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("total"),
            None
        );
    }

    #[test]
    fn test_extract_aggregate_from_order_column_edge_cases() {
        // Matching is case-insensitive and the result is upper-cased.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("count_distinct(user_id)"),
            Some("COUNT_DISTINCT(USER_ID)".to_string())
        );

        // A name that merely starts with an aggregate name is not an aggregate call.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUMMARY(x)"),
            None
        );

        // Missing closing parenthesis.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(x"),
            None
        );
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "sales_amount".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };

        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "SUM(SALES_AMOUNT)"
        );
    }

    #[test]
    fn test_is_aggregate_function() {
        let sum_expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(OrderByAliasTransformer::is_aggregate_function(&sum_expr));

        let upper_expr = SqlExpression::FunctionCall {
            name: "UPPER".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(!OrderByAliasTransformer::is_aggregate_function(&upper_expr));
    }

    #[test]
    fn test_is_aggregate_function_case_and_shape() {
        let lower_count = SqlExpression::FunctionCall {
            name: "count".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(OrderByAliasTransformer::is_aggregate_function(&lower_count));

        // A plain column is never an aggregate, whatever its name.
        let column = SqlExpression::Column(ColumnRef {
            name: "sum".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        });
        assert!(!OrderByAliasTransformer::is_aggregate_function(&column));
    }

    #[test]
    fn test_generate_alias_is_sequential() {
        let mut transformer = OrderByAliasTransformer::new();
        assert_eq!(transformer.generate_alias(), "__orderby_agg_1");
        assert_eq!(transformer.generate_alias(), "__orderby_agg_2");
    }
}