sql_cli/query_plan/
order_by_alias_transformer.rs1use crate::query_plan::pipeline::ASTTransformer;
41use crate::sql::parser::ast::{ColumnRef, SelectItem, SelectStatement, SqlExpression};
42use anyhow::Result;
43use std::collections::HashMap;
44use tracing::debug;
45
/// AST transformer that rewrites aggregate expressions appearing in `ORDER BY`
/// (e.g. `ORDER BY SUM(amount)`) into references to the alias of the matching
/// `SELECT` item, generating an alias for that item when it has none.
#[derive(Debug)]
pub struct OrderByAliasTransformer {
    /// Number of aliases generated so far; appended to `__orderby_agg_` so
    /// every generated alias from this instance is unique.
    alias_counter: usize,
}
51
52impl OrderByAliasTransformer {
53 pub fn new() -> Self {
54 Self { alias_counter: 0 }
55 }
56
57 fn is_aggregate_function(expr: &SqlExpression) -> bool {
59 matches!(
60 expr,
61 SqlExpression::FunctionCall { name, .. }
62 if matches!(
63 name.to_uppercase().as_str(),
64 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
65 )
66 )
67 }
68
69 fn generate_alias(&mut self) -> String {
71 self.alias_counter += 1;
72 format!("__orderby_agg_{}", self.alias_counter)
73 }
74
75 fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
81 match expr {
82 SqlExpression::FunctionCall { name, args, .. } => {
83 let args_str = args
84 .iter()
85 .map(|arg| match arg {
86 SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
87 SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
89 SqlExpression::StringLiteral(s) => format!("'{}'", s).to_uppercase(),
90 SqlExpression::NumberLiteral(n) => n.to_uppercase(),
91 _ => format!("{:?}", arg).to_uppercase(), })
93 .collect::<Vec<_>>()
94 .join(", ");
95 format!("{}({})", name.to_uppercase(), args_str)
96 }
97 _ => String::new(),
98 }
99 }
100
101 fn build_aggregate_map(
104 &mut self,
105 select_items: &mut Vec<SelectItem>,
106 ) -> HashMap<String, String> {
107 let mut aggregate_map = HashMap::new();
108
109 for item in select_items.iter_mut() {
110 if let SelectItem::Expression { expr, alias, .. } = item {
111 if Self::is_aggregate_function(expr) {
112 let normalized = Self::normalize_aggregate_expr(expr);
113
114 if alias.is_empty() {
116 *alias = self.generate_alias();
117 debug!(
118 "Generated alias '{}' for aggregate in ORDER BY: {}",
119 alias, normalized
120 );
121 }
122
123 debug!("Mapped aggregate '{}' to alias '{}'", normalized, alias);
124 aggregate_map.insert(normalized, alias.clone());
125 }
126 }
127 }
128
129 aggregate_map
130 }
131
132 fn expression_to_string(expr: &SqlExpression) -> String {
135 match expr {
136 SqlExpression::Column(col_ref) => col_ref.name.to_uppercase(),
137 SqlExpression::StringLiteral(s) if s == "*" => "*".to_string(),
139 SqlExpression::StringLiteral(s) => format!("'{}'", s),
140 SqlExpression::FunctionCall { name, args, .. } => {
141 let args_str = args
142 .iter()
143 .map(|arg| Self::expression_to_string(arg))
144 .collect::<Vec<_>>()
145 .join(", ");
146 format!("{}({})", name.to_uppercase(), args_str)
147 }
148 _ => "expr".to_string(), }
150 }
151
152 fn extract_aggregate_from_order_column(column_name: &str) -> Option<String> {
155 let upper = column_name.to_uppercase();
158
159 if (upper.starts_with("COUNT(") && upper.ends_with(')'))
160 || (upper.starts_with("SUM(") && upper.ends_with(')'))
161 || (upper.starts_with("AVG(") && upper.ends_with(')'))
162 || (upper.starts_with("MIN(") && upper.ends_with(')'))
163 || (upper.starts_with("MAX(") && upper.ends_with(')'))
164 || (upper.starts_with("COUNT_DISTINCT(") && upper.ends_with(')'))
165 {
166 Some(upper)
168 } else {
169 None
170 }
171 }
172}
173
174impl Default for OrderByAliasTransformer {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180impl ASTTransformer for OrderByAliasTransformer {
181 fn name(&self) -> &str {
182 "OrderByAliasTransformer"
183 }
184
185 fn description(&self) -> &str {
186 "Rewrites ORDER BY aggregate expressions to use SELECT aliases"
187 }
188
189 fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
190 if stmt.order_by.is_none() {
192 return Ok(stmt);
193 }
194
195 let aggregate_map = self.build_aggregate_map(&mut stmt.select_items);
197
198 if aggregate_map.is_empty() {
199 return Ok(stmt);
201 }
202
203 if let Some(order_by) = stmt.order_by.as_mut() {
205 let mut modified = false;
206
207 for order_col in order_by.iter_mut() {
208 let expr_str = Self::expression_to_string(&order_col.expr);
210
211 if let Some(normalized) = Self::extract_aggregate_from_order_column(&expr_str) {
213 if let Some(alias) = aggregate_map.get(&normalized) {
215 debug!("Rewriting ORDER BY '{}' to use alias '{}'", expr_str, alias);
216 order_col.expr = SqlExpression::Column(ColumnRef::unquoted(alias.clone()));
218 modified = true;
219 }
220 }
221 }
222
223 if modified {
224 debug!(
225 "Rewrote ORDER BY to use {} aggregate alias(es)",
226 aggregate_map.len()
227 );
228 }
229 }
230
231 Ok(stmt)
232 }
233}
234
#[cfg(test)]
mod tests {
    // `ColumnRef` already comes in through `super::*`; only `QuoteStyle` is
    // needed in addition.
    use super::*;
    use crate::sql::parser::ast::QuoteStyle;

    #[test]
    fn test_extract_aggregate_from_order_column() {
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(sales_amount)"),
            Some("SUM(SALES_AMOUNT)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT(*)"),
            Some("COUNT(*)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("region"),
            None
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("total"),
            None
        );

        // Matching is case-insensitive and the result is upper-cased.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("count(*)"),
            Some("COUNT(*)".to_string())
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("COUNT_DISTINCT(id)"),
            Some("COUNT_DISTINCT(ID)".to_string())
        );

        // Only complete calls count: missing ')' or a name that merely starts
        // with an aggregate name is rejected.
        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUM(x"),
            None
        );

        assert_eq!(
            OrderByAliasTransformer::extract_aggregate_from_order_column("SUMMARY"),
            None
        );
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "sales_amount".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };

        assert_eq!(
            OrderByAliasTransformer::normalize_aggregate_expr(&expr),
            "SUM(SALES_AMOUNT)"
        );
    }

    #[test]
    fn test_is_aggregate_function() {
        let sum_expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(OrderByAliasTransformer::is_aggregate_function(&sum_expr));

        // Function names are matched case-insensitively.
        let lower_sum_expr = SqlExpression::FunctionCall {
            name: "sum".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(OrderByAliasTransformer::is_aggregate_function(&lower_sum_expr));

        let upper_expr = SqlExpression::FunctionCall {
            name: "UPPER".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(!OrderByAliasTransformer::is_aggregate_function(&upper_expr));
    }
}