1use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// AST pass that rewrites GROUP BY items which name a SELECT alias into the
/// aliased expression itself (e.g. `GROUP BY grp` -> `GROUP BY id % 3` when the
/// SELECT list contains `id % 3 AS grp`).
#[derive(Debug)]
pub struct GroupByAliasExpander {
    /// Number of GROUP BY items rewritten since construction or the last
    /// `begin()` call.
    expansions: usize,
}
46
47impl GroupByAliasExpander {
48 pub fn new() -> Self {
49 Self { expansions: 0 }
50 }
51
52 fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55 let mut aliases = HashMap::new();
56
57 for item in select_items {
58 if let SelectItem::Expression { expr, alias, .. } = item {
59 if !alias.is_empty() {
60 aliases.insert(alias.clone(), expr.clone());
61 debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62 }
63 }
64 }
65
66 aliases
67 }
68
69 fn expand_expression(
72 expr: &SqlExpression,
73 aliases: &HashMap<String, SqlExpression>,
74 ) -> (SqlExpression, bool) {
75 match expr {
76 SqlExpression::Column(col_ref) => {
78 if col_ref.table_prefix.is_none() {
80 if let Some(alias_expr) = aliases.get(&col_ref.name) {
81 debug!(
82 "Expanding alias '{}' in GROUP BY to: {:?}",
83 col_ref.name, alias_expr
84 );
85 return (alias_expr.clone(), true);
86 }
87 }
88 (expr.clone(), false)
89 }
90
91 _ => (expr.clone(), false),
94 }
95 }
96
97 fn expand_group_by(
99 &mut self,
100 group_by: &mut Vec<SqlExpression>,
101 aliases: &HashMap<String, SqlExpression>,
102 ) -> bool {
103 let mut any_expanded = false;
104
105 for expr in group_by.iter_mut() {
106 let (new_expr, expanded) = Self::expand_expression(expr, aliases);
107 if expanded {
108 *expr = new_expr;
109 any_expanded = true;
110 self.expansions += 1;
111 }
112 }
113
114 any_expanded
115 }
116}
117
118impl Default for GroupByAliasExpander {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl ASTTransformer for GroupByAliasExpander {
125 fn name(&self) -> &str {
126 "GroupByAliasExpander"
127 }
128
129 fn description(&self) -> &str {
130 "Expands SELECT aliases in GROUP BY clauses to their full expressions"
131 }
132
133 fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
134 if stmt.group_by.is_none() {
136 return Ok(stmt);
137 }
138
139 let aliases = Self::extract_aliases(&stmt.select_items);
141
142 if aliases.is_empty() {
143 return Ok(stmt);
145 }
146
147 if let Some(ref mut group_by) = stmt.group_by {
149 let expanded = self.expand_group_by(group_by, &aliases);
150 if expanded {
151 debug!(
152 "Expanded {} alias reference(s) in GROUP BY clause",
153 self.expansions
154 );
155 }
156 }
157
158 Ok(stmt)
159 }
160
161 fn begin(&mut self) -> Result<()> {
162 self.expansions = 0;
164 Ok(())
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Unqualified, unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// The `id % 3` expression used as the aliased value in most cases below.
    fn id_mod_3() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(column("id")),
            op: "%".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("3".to_string())),
        }
    }

    /// A SELECT item `expr AS alias` with no attached comments.
    fn select_as(expr: SqlExpression, alias: &str) -> SelectItem {
        SelectItem::Expression {
            expr,
            alias: alias.to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// Alias table mapping `grp` to `id % 3`.
    fn grp_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("grp".to_string(), id_mod_3())])
    }

    /// Single-argument call `func(date)`.
    fn call_on_date(func: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: func.to_string(),
            args: vec![column("date")],
            distinct: false,
        }
    }

    #[test]
    fn test_extract_aliases() {
        let items = [select_as(id_mod_3(), "grp")];

        let aliases = GroupByAliasExpander::extract_aliases(&items);

        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("grp"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&column("grp"), &grp_aliases());

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_does_not_expand_full_expressions() {
        // The GROUP BY item is already the full expression, not an alias name.
        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&id_mod_3(), &grp_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_transform_with_no_group_by() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            group_by: None,
            ..Default::default()
        };

        assert!(transformer.transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            select_items: vec![select_as(id_mod_3(), "grp")],
            group_by: Some(vec![column("grp")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();
        let group_by = result
            .group_by
            .as_ref()
            .expect("Expected GROUP BY clause");

        assert_eq!(group_by.len(), 1);
        assert!(matches!(group_by[0], SqlExpression::BinaryOp { .. }));
        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        // `t.grp` refers to a table column, never to the SELECT alias `grp`.
        let qualified = SqlExpression::Column(ColumnRef {
            name: "grp".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&qualified, &grp_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }

    #[test]
    fn test_multiple_aliases_in_group_by() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            select_items: vec![
                select_as(call_on_date("YEAR"), "yr"),
                select_as(call_on_date("MONTH"), "mon"),
            ],
            group_by: Some(vec![column("yr"), column("mon")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();
        let group_by = result
            .group_by
            .as_ref()
            .expect("Expected GROUP BY clause");

        assert_eq!(group_by.len(), 2);
        assert!(group_by
            .iter()
            .all(|expr| matches!(expr, SqlExpression::FunctionCall { .. })));
        assert_eq!(transformer.expansions, 2);
    }
}