Skip to main content

sql_cli/query_plan/
group_by_alias_expander.rs

1//! GROUP BY clause alias expansion transformer
2//!
3//! This transformer allows users to reference SELECT clause aliases in GROUP BY clauses
4//! by automatically expanding those aliases to their full expressions.
5//!
6//! # Problem
7//!
8//! Users often want to group by a complex expression using its alias:
9//! ```sql
10//! SELECT id % 3 as grp, COUNT(*) FROM t GROUP BY grp
11//! ```
12//!
13//! This fails because GROUP BY is evaluated before SELECT, so aliases don't exist yet.
14//!
15//! # Solution
16//!
17//! The transformer rewrites to:
18//! ```sql
19//! SELECT id % 3 as grp, COUNT(*) FROM t GROUP BY id % 3
20//! ```
21//!
22//! # Algorithm
23//!
24//! 1. Extract all aliases from SELECT clause and their corresponding expressions
25//! 2. Scan GROUP BY clause for column references
26//! 3. If a column reference matches an alias name, replace it with the full expression
27//! 4. Only expand simple column references (not qualified table.column)
28//!
29//! # Limitations
30//!
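//! - Only works for simple column aliases (not table.alias references)
//! - Only a GROUP BY item that is itself a bare alias reference is expanded; an alias
//!   nested inside a larger expression (e.g. `GROUP BY grp + 1`) is left untouched
//! - Aliases take precedence over actual column names if they conflict
//! - Complex expressions are duplicated (no common subexpression elimination)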
34
35use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, SqlExpression, TableSource};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// Transformer that expands SELECT aliases in GROUP BY clauses.
///
/// The only state it carries is a running count of how many alias references
/// have been replaced.
#[derive(Debug)]
pub struct GroupByAliasExpander {
    /// Number of alias references replaced so far.
    expansions: usize,
}
46
47impl GroupByAliasExpander {
48    pub fn new() -> Self {
49        Self { expansions: 0 }
50    }
51
52    /// Extract aliases from SELECT clause
53    /// Returns a map of alias name -> expression
54    fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55        let mut aliases = HashMap::new();
56
57        for item in select_items {
58            if let SelectItem::Expression { expr, alias, .. } = item {
59                if !alias.is_empty() {
60                    aliases.insert(alias.clone(), expr.clone());
61                    debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62                }
63            }
64        }
65
66        aliases
67    }
68
69    /// Expand aliases in a single GROUP BY expression
70    /// Returns the expanded expression and whether any expansion occurred
71    fn expand_expression(
72        expr: &SqlExpression,
73        aliases: &HashMap<String, SqlExpression>,
74    ) -> (SqlExpression, bool) {
75        match expr {
76            // Check if this column reference is actually an alias
77            SqlExpression::Column(col_ref) => {
78                // Only expand if it's a simple column (no table prefix)
79                if col_ref.table_prefix.is_none() {
80                    if let Some(alias_expr) = aliases.get(&col_ref.name) {
81                        debug!(
82                            "Expanding alias '{}' in GROUP BY to: {:?}",
83                            col_ref.name, alias_expr
84                        );
85                        return (alias_expr.clone(), true);
86                    }
87                }
88                (expr.clone(), false)
89            }
90
91            // For all other expressions (functions, binary ops, etc.), return as-is
92            // GROUP BY typically uses simple column references or the full expression
93            _ => (expr.clone(), false),
94        }
95    }
96
97    /// Expand aliases in GROUP BY clause
98    fn expand_group_by(
99        &mut self,
100        group_by: &mut Vec<SqlExpression>,
101        aliases: &HashMap<String, SqlExpression>,
102    ) -> bool {
103        let mut any_expanded = false;
104
105        for expr in group_by.iter_mut() {
106            let (new_expr, expanded) = Self::expand_expression(expr, aliases);
107            if expanded {
108                *expr = new_expr;
109                any_expanded = true;
110                self.expansions += 1;
111            }
112        }
113
114        any_expanded
115    }
116
117    /// Transform a SelectStatement and recurse into nested SELECT statements
118    /// (CTEs, FROM subqueries, set operations). Needed because GROUP BY alias
119    /// expansion must happen in every scope where GROUP BY appears.
120    #[allow(deprecated)]
121    fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
122        // Recurse into CTEs
123        for cte in stmt.ctes.iter_mut() {
124            if let CTEType::Standard(ref mut inner) = cte.cte_type {
125                let taken = std::mem::take(inner);
126                *inner = self.transform_statement(taken)?;
127            }
128        }
129
130        // Recurse into FROM DerivedTable subqueries
131        if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
132            let taken = std::mem::take(query.as_mut());
133            **query = self.transform_statement(taken)?;
134        }
135
136        // Recurse into legacy from_subquery
137        if let Some(subq) = stmt.from_subquery.as_mut() {
138            let taken = std::mem::take(subq.as_mut());
139            **subq = self.transform_statement(taken)?;
140        }
141
142        // Recurse into set operation right-hand sides
143        for (_op, rhs) in stmt.set_operations.iter_mut() {
144            let taken = std::mem::take(rhs.as_mut());
145            **rhs = self.transform_statement(taken)?;
146        }
147
148        // Apply GROUP BY alias expansion at this level
149        self.apply_expansion(&mut stmt);
150
151        Ok(stmt)
152    }
153
154    /// Apply GROUP BY alias expansion to a single statement (no recursion).
155    fn apply_expansion(&mut self, stmt: &mut SelectStatement) {
156        if stmt.group_by.is_none() {
157            return;
158        }
159
160        let aliases = Self::extract_aliases(&stmt.select_items);
161        if aliases.is_empty() {
162            return;
163        }
164
165        if let Some(ref mut group_by) = stmt.group_by {
166            let expanded = self.expand_group_by(group_by, &aliases);
167            if expanded {
168                debug!(
169                    "Expanded {} alias reference(s) in GROUP BY clause",
170                    self.expansions
171                );
172            }
173        }
174    }
175}
176
177impl Default for GroupByAliasExpander {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
183impl ASTTransformer for GroupByAliasExpander {
184    fn name(&self) -> &str {
185        "GroupByAliasExpander"
186    }
187
188    fn description(&self) -> &str {
189        "Expands SELECT aliases in GROUP BY clauses to their full expressions"
190    }
191
192    fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
193        self.transform_statement(stmt)
194    }
195
196    fn begin(&mut self) -> Result<()> {
197        // Reset expansion counter for each query
198        self.expansions = 0;
199        Ok(())
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, QuoteStyle};

    /// Bare, unquoted column reference.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// The expression `id % 3`, used as the aliased expression in most tests.
    fn id_mod_3() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(col("id")),
            op: "%".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("3".to_string())),
        }
    }

    /// Alias map containing only `grp -> id % 3`.
    fn grp_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("grp".to_string(), id_mod_3())])
    }

    /// SELECT item `expr AS alias` with no attached comments.
    fn aliased(expr: SqlExpression, alias: &str) -> SelectItem {
        SelectItem::Expression {
            expr,
            alias: alias.to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// Non-DISTINCT call `name(date)`.
    fn date_fn(name: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args: vec![col("date")],
            distinct: false,
        }
    }

    #[test]
    fn test_extract_aliases() {
        // Built with an explicit struct literal rather than `ColumnRef::unquoted`.
        let grp_expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef {
                name: "id".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })),
            op: "%".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("3".to_string())),
        };
        let select_items = vec![aliased(grp_expr.clone(), "grp")];

        let aliases = GroupByAliasExpander::extract_aliases(&select_items);

        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("grp"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let aliases = grp_aliases();
        let expr = col("grp");

        let (expanded, changed) = GroupByAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_does_not_expand_full_expressions() {
        let aliases = grp_aliases();
        // Full expression should not be expanded (it's not a simple column reference)
        let expr = id_mod_3();

        let (expanded, changed) = GroupByAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_transform_with_no_group_by() {
        let mut transformer = GroupByAliasExpander::new();
        let stmt = SelectStatement {
            group_by: None,
            ..Default::default()
        };

        assert!(transformer.transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = GroupByAliasExpander::new();
        let grp_expr = id_mod_3();
        let stmt = SelectStatement {
            select_items: vec![aliased(grp_expr.clone(), "grp")],
            group_by: Some(vec![col("grp")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // Check that GROUP BY was rewritten
        match &result.group_by {
            Some(group_by) => {
                assert_eq!(group_by.len(), 1);
                // Should now be the expanded expression (id % 3), not the column "grp"
                assert!(matches!(group_by[0], SqlExpression::BinaryOp { .. }));
            }
            None => panic!("Expected GROUP BY clause"),
        }
        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        let aliases = grp_aliases();
        // Column with table prefix should NOT be expanded
        let expr = SqlExpression::Column(ColumnRef {
            name: "grp".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) = GroupByAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }

    #[test]
    fn test_multiple_aliases_in_group_by() {
        let mut transformer = GroupByAliasExpander::new();
        let year_expr = date_fn("YEAR");
        let month_expr = date_fn("MONTH");
        let stmt = SelectStatement {
            select_items: vec![
                aliased(year_expr.clone(), "yr"),
                aliased(month_expr.clone(), "mon"),
            ],
            group_by: Some(vec![col("yr"), col("mon")]),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // Check that both GROUP BY expressions were expanded
        match &result.group_by {
            Some(group_by) => {
                assert_eq!(group_by.len(), 2);
                assert!(matches!(group_by[0], SqlExpression::FunctionCall { .. }));
                assert!(matches!(group_by[1], SqlExpression::FunctionCall { .. }));
            }
            None => panic!("Expected GROUP BY clause"),
        }
        assert_eq!(transformer.expansions, 2);
    }
}