sql_cli/query_plan/
group_by_alias_expander.rs

//! GROUP BY clause alias expansion transformer
//!
//! This transformer lets users reference SELECT-clause aliases in GROUP BY clauses
//! by automatically expanding those aliases to their full expressions.
//!
//! # Problem
//!
//! Users often want to group by a complex expression using its alias:
//! ```sql
//! SELECT id % 3 as grp, COUNT(*) FROM t GROUP BY grp
//! ```
//!
//! This fails because GROUP BY is evaluated before the SELECT list, so its aliases don't exist yet.
//!
//! # Solution
//!
//! The transformer rewrites the query to:
//! ```sql
//! SELECT id % 3 as grp, COUNT(*) FROM t GROUP BY id % 3
//! ```
//!
//! # Algorithm
//!
//! 1. Extract all aliases from the SELECT clause together with their expressions
//! 2. Inspect each top-level GROUP BY expression
//! 3. If it is a column reference whose name matches an alias, replace it with the full expression
//! 4. Only expand simple column references (not qualified `table.column`)
//!
//! # Limitations
//!
//! - Only works for simple column aliases (not `table.alias` references)
//! - Only bare top-level GROUP BY items are expanded; an alias used inside a larger
//!   expression (e.g. `GROUP BY grp + 1`) is left unchanged
//! - Alias matching is an exact, case-sensitive string comparison
//! - Aliases take precedence over actual column names if they conflict
//! - If the same alias is declared more than once in SELECT, the last declaration wins
//! - Complex expressions are duplicated (no common subexpression elimination)
34
35use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// Transformer that expands SELECT aliases in GROUP BY clauses.
///
/// The only state it carries is a running count of GROUP BY items that have been
/// replaced by their aliased SELECT expression; the count is reset to zero in `begin()`.
#[derive(Debug)]
pub struct GroupByAliasExpander {
    /// Number of GROUP BY items rewritten since the counter was last reset.
    expansions: usize,
}
46
47impl GroupByAliasExpander {
48    pub fn new() -> Self {
49        Self { expansions: 0 }
50    }
51
52    /// Extract aliases from SELECT clause
53    /// Returns a map of alias name -> expression
54    fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55        let mut aliases = HashMap::new();
56
57        for item in select_items {
58            if let SelectItem::Expression { expr, alias, .. } = item {
59                if !alias.is_empty() {
60                    aliases.insert(alias.clone(), expr.clone());
61                    debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62                }
63            }
64        }
65
66        aliases
67    }
68
69    /// Expand aliases in a single GROUP BY expression
70    /// Returns the expanded expression and whether any expansion occurred
71    fn expand_expression(
72        expr: &SqlExpression,
73        aliases: &HashMap<String, SqlExpression>,
74    ) -> (SqlExpression, bool) {
75        match expr {
76            // Check if this column reference is actually an alias
77            SqlExpression::Column(col_ref) => {
78                // Only expand if it's a simple column (no table prefix)
79                if col_ref.table_prefix.is_none() {
80                    if let Some(alias_expr) = aliases.get(&col_ref.name) {
81                        debug!(
82                            "Expanding alias '{}' in GROUP BY to: {:?}",
83                            col_ref.name, alias_expr
84                        );
85                        return (alias_expr.clone(), true);
86                    }
87                }
88                (expr.clone(), false)
89            }
90
91            // For all other expressions (functions, binary ops, etc.), return as-is
92            // GROUP BY typically uses simple column references or the full expression
93            _ => (expr.clone(), false),
94        }
95    }
96
97    /// Expand aliases in GROUP BY clause
98    fn expand_group_by(
99        &mut self,
100        group_by: &mut Vec<SqlExpression>,
101        aliases: &HashMap<String, SqlExpression>,
102    ) -> bool {
103        let mut any_expanded = false;
104
105        for expr in group_by.iter_mut() {
106            let (new_expr, expanded) = Self::expand_expression(expr, aliases);
107            if expanded {
108                *expr = new_expr;
109                any_expanded = true;
110                self.expansions += 1;
111            }
112        }
113
114        any_expanded
115    }
116}
117
118impl Default for GroupByAliasExpander {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl ASTTransformer for GroupByAliasExpander {
125    fn name(&self) -> &str {
126        "GroupByAliasExpander"
127    }
128
129    fn description(&self) -> &str {
130        "Expands SELECT aliases in GROUP BY clauses to their full expressions"
131    }
132
133    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
134        // Only process if there's a GROUP BY clause
135        if stmt.group_by.is_none() {
136            return Ok(stmt);
137        }
138
139        // Step 1: Extract all aliases from SELECT clause
140        let aliases = Self::extract_aliases(&stmt.select_items);
141
142        if aliases.is_empty() {
143            // No aliases to expand
144            return Ok(stmt);
145        }
146
147        // Step 2: Expand aliases in GROUP BY clause
148        if let Some(ref mut group_by) = stmt.group_by {
149            let expanded = self.expand_group_by(group_by, &aliases);
150            if expanded {
151                debug!(
152                    "Expanded {} alias reference(s) in GROUP BY clause",
153                    self.expansions
154                );
155            }
156        }
157
158        Ok(stmt)
159    }
160
161    fn begin(&mut self) -> Result<()> {
162        // Reset expansion counter for each query
163        self.expansions = 0;
164        Ok(())
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Bare, unquoted, unqualified column reference.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// The `id % 3` expression used as the aliased SELECT item in most tests.
    fn id_mod_3() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(col("id")),
            op: "%".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("3".to_string())),
        }
    }

    /// SELECT list entry `expr AS alias` with no comments attached.
    fn aliased(expr: SqlExpression, alias: &str) -> SelectItem {
        SelectItem::Expression {
            expr,
            alias: alias.to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// Alias map holding only `grp -> id % 3`.
    fn grp_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("grp".to_string(), id_mod_3())])
    }

    /// Single-argument, non-DISTINCT function call such as `YEAR(date)`.
    fn call(name: &str, arg: SqlExpression) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args: vec![arg],
            distinct: false,
        }
    }

    #[test]
    fn test_extract_aliases() {
        let aliases = GroupByAliasExpander::extract_aliases(&[aliased(id_mod_3(), "grp")]);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("grp"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&col("grp"), &grp_aliases());

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_does_not_expand_full_expressions() {
        // The full expression is not a bare column reference, so it stays as written.
        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&id_mod_3(), &grp_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_transform_with_no_group_by() {
        let stmt = SelectStatement {
            group_by: None,
            ..Default::default()
        };

        assert!(GroupByAliasExpander::new().transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            select_items: vec![aliased(id_mod_3(), "grp")],
            group_by: Some(vec![col("grp")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // GROUP BY should now hold the expanded `id % 3`, not the column `grp`.
        let group_by = result.group_by.as_ref().expect("Expected GROUP BY clause");
        assert_eq!(group_by.len(), 1);
        assert!(matches!(group_by[0], SqlExpression::BinaryOp { .. }));

        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        // A table-qualified `t.grp` must not be mistaken for the alias `grp`.
        let qualified = SqlExpression::Column(ColumnRef {
            name: "grp".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) =
            GroupByAliasExpander::expand_expression(&qualified, &grp_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }

    #[test]
    fn test_multiple_aliases_in_group_by() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            select_items: vec![
                aliased(call("YEAR", col("date")), "yr"),
                aliased(call("MONTH", col("date")), "mon"),
            ],
            group_by: Some(vec![col("yr"), col("mon")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // Both GROUP BY items should have been replaced by their function calls.
        let group_by = result.group_by.as_ref().expect("Expected GROUP BY clause");
        assert_eq!(group_by.len(), 2);
        assert!(group_by
            .iter()
            .all(|item| matches!(item, SqlExpression::FunctionCall { .. })));

        assert_eq!(transformer.expansions, 2);
    }
}