1use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, SqlExpression, TableSource};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// Pipeline transformer that rewrites references to SELECT-list aliases in
/// `GROUP BY` clauses into the aliased expressions themselves
/// (e.g. `SELECT id % 3 AS grp ... GROUP BY grp` becomes `GROUP BY id % 3`).
#[derive(Debug)]
pub struct GroupByAliasExpander {
    /// Running count of GROUP BY entries replaced since the counter was last
    /// reset; it starts at zero and is reset to zero by `begin`.
    expansions: usize,
}
46
47impl GroupByAliasExpander {
48 pub fn new() -> Self {
49 Self { expansions: 0 }
50 }
51
52 fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55 let mut aliases = HashMap::new();
56
57 for item in select_items {
58 if let SelectItem::Expression { expr, alias, .. } = item {
59 if !alias.is_empty() {
60 aliases.insert(alias.clone(), expr.clone());
61 debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62 }
63 }
64 }
65
66 aliases
67 }
68
69 fn expand_expression(
72 expr: &SqlExpression,
73 aliases: &HashMap<String, SqlExpression>,
74 ) -> (SqlExpression, bool) {
75 match expr {
76 SqlExpression::Column(col_ref) => {
78 if col_ref.table_prefix.is_none() {
80 if let Some(alias_expr) = aliases.get(&col_ref.name) {
81 debug!(
82 "Expanding alias '{}' in GROUP BY to: {:?}",
83 col_ref.name, alias_expr
84 );
85 return (alias_expr.clone(), true);
86 }
87 }
88 (expr.clone(), false)
89 }
90
91 _ => (expr.clone(), false),
94 }
95 }
96
97 fn expand_group_by(
99 &mut self,
100 group_by: &mut Vec<SqlExpression>,
101 aliases: &HashMap<String, SqlExpression>,
102 ) -> bool {
103 let mut any_expanded = false;
104
105 for expr in group_by.iter_mut() {
106 let (new_expr, expanded) = Self::expand_expression(expr, aliases);
107 if expanded {
108 *expr = new_expr;
109 any_expanded = true;
110 self.expansions += 1;
111 }
112 }
113
114 any_expanded
115 }
116
117 #[allow(deprecated)]
121 fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
122 for cte in stmt.ctes.iter_mut() {
124 if let CTEType::Standard(ref mut inner) = cte.cte_type {
125 let taken = std::mem::take(inner);
126 *inner = self.transform_statement(taken)?;
127 }
128 }
129
130 if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
132 let taken = std::mem::take(query.as_mut());
133 **query = self.transform_statement(taken)?;
134 }
135
136 if let Some(subq) = stmt.from_subquery.as_mut() {
138 let taken = std::mem::take(subq.as_mut());
139 **subq = self.transform_statement(taken)?;
140 }
141
142 for (_op, rhs) in stmt.set_operations.iter_mut() {
144 let taken = std::mem::take(rhs.as_mut());
145 **rhs = self.transform_statement(taken)?;
146 }
147
148 self.apply_expansion(&mut stmt);
150
151 Ok(stmt)
152 }
153
154 fn apply_expansion(&mut self, stmt: &mut SelectStatement) {
156 if stmt.group_by.is_none() {
157 return;
158 }
159
160 let aliases = Self::extract_aliases(&stmt.select_items);
161 if aliases.is_empty() {
162 return;
163 }
164
165 if let Some(ref mut group_by) = stmt.group_by {
166 let expanded = self.expand_group_by(group_by, &aliases);
167 if expanded {
168 debug!(
169 "Expanded {} alias reference(s) in GROUP BY clause",
170 self.expansions
171 );
172 }
173 }
174 }
175}
176
177impl Default for GroupByAliasExpander {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
183impl ASTTransformer for GroupByAliasExpander {
184 fn name(&self) -> &str {
185 "GroupByAliasExpander"
186 }
187
188 fn description(&self) -> &str {
189 "Expands SELECT aliases in GROUP BY clauses to their full expressions"
190 }
191
192 fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
193 self.transform_statement(stmt)
194 }
195
196 fn begin(&mut self) -> Result<()> {
197 self.expansions = 0;
199 Ok(())
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Unqualified, unquoted column reference.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// The expression `id % 3`, used as the aliased SELECT expression in most tests.
    fn id_mod_3() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(column("id")),
            op: "%".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("3".to_string())),
        }
    }

    /// `expr AS alias` with no attached comments.
    fn aliased(expr: SqlExpression, alias: &str) -> SelectItem {
        SelectItem::Expression {
            expr,
            alias: alias.to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// `SELECT id % 3 AS grp ... GROUP BY <group_by>`.
    fn grouped_by(group_by: Vec<SqlExpression>) -> SelectStatement {
        SelectStatement {
            select_items: vec![aliased(id_mod_3(), "grp")],
            group_by: Some(group_by),
            ..Default::default()
        }
    }

    #[test]
    fn test_extract_aliases() {
        let select_items = vec![aliased(id_mod_3(), "grp")];

        let aliases = GroupByAliasExpander::extract_aliases(&select_items);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("grp"));
        assert!(matches!(
            aliases.get("grp"),
            Some(SqlExpression::BinaryOp { .. })
        ));
    }

    #[test]
    fn test_extract_aliases_ignores_empty_alias() {
        let select_items = vec![aliased(column("id"), "")];

        let aliases = GroupByAliasExpander::extract_aliases(&select_items);
        assert!(aliases.is_empty());
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let aliases = HashMap::from([("grp".to_string(), id_mod_3())]);

        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&column("grp"), &aliases);

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_does_not_expand_full_expressions() {
        let aliases = HashMap::from([("grp".to_string(), id_mod_3())]);

        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&id_mod_3(), &aliases);

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_transform_with_no_group_by() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            group_by: None,
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();
        assert!(result.group_by.is_none());
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = GroupByAliasExpander::new();

        let result = transformer
            .transform(grouped_by(vec![column("grp")]))
            .unwrap();

        let group_by = result.group_by.expect("Expected GROUP BY clause");
        assert_eq!(group_by.len(), 1);
        assert!(matches!(group_by[0], SqlExpression::BinaryOp { .. }));
        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_leaves_non_alias_columns_untouched() {
        let mut transformer = GroupByAliasExpander::new();

        let result = transformer
            .transform(grouped_by(vec![column("region")]))
            .unwrap();

        let group_by = result.group_by.expect("Expected GROUP BY clause");
        assert_eq!(group_by.len(), 1);
        assert!(matches!(&group_by[0], SqlExpression::Column(c) if c.name == "region"));
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        let aliases = HashMap::from([("grp".to_string(), id_mod_3())]);

        let expr = SqlExpression::Column(ColumnRef {
            name: "grp".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) = GroupByAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }

    #[test]
    fn test_multiple_aliases_in_group_by() {
        let mut transformer = GroupByAliasExpander::new();

        let year_expr = SqlExpression::FunctionCall {
            name: "YEAR".to_string(),
            args: vec![column("date")],
            distinct: false,
        };

        let month_expr = SqlExpression::FunctionCall {
            name: "MONTH".to_string(),
            args: vec![column("date")],
            distinct: false,
        };

        let stmt = SelectStatement {
            select_items: vec![aliased(year_expr, "yr"), aliased(month_expr, "mon")],
            group_by: Some(vec![column("yr"), column("mon")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        let group_by = result.group_by.expect("Expected GROUP BY clause");
        assert_eq!(group_by.len(), 2);
        assert!(matches!(group_by[0], SqlExpression::FunctionCall { .. }));
        assert!(matches!(group_by[1], SqlExpression::FunctionCall { .. }));
        assert_eq!(transformer.expansions, 2);
    }

    // `from_subquery` is presumably a deprecated field (the transformer reads
    // it under `#[allow(deprecated)]`), so silence the lint here as well.
    #[test]
    #[allow(deprecated)]
    fn test_expands_alias_inside_from_subquery() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            from_subquery: Some(Box::new(grouped_by(vec![column("grp")]))),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        let inner = result
            .from_subquery
            .as_deref()
            .expect("Expected FROM subquery");
        let group_by = inner
            .group_by
            .as_deref()
            .expect("Expected GROUP BY clause");
        assert_eq!(group_by.len(), 1);
        assert!(matches!(group_by[0], SqlExpression::BinaryOp { .. }));
        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_begin_resets_expansion_count() {
        let mut transformer = GroupByAliasExpander::new();
        transformer
            .transform(grouped_by(vec![column("grp")]))
            .unwrap();
        assert_eq!(transformer.expansions, 1);

        transformer.begin().unwrap();
        assert_eq!(transformer.expansions, 0);
    }
}