sql_cli/query_plan/
having_alias_transformer.rs

1//! HAVING clause auto-aliasing transformer
2//!
3//! This transformer automatically adds aliases to aggregate functions in SELECT
4//! clauses and rewrites HAVING clauses to use those aliases instead of the
5//! aggregate function expressions.
6//!
7//! # Problem
8//!
9//! Users often write queries like:
10//! ```sql
11//! SELECT region, COUNT(*) FROM sales GROUP BY region HAVING COUNT(*) > 5
12//! ```
13//!
14//! This fails because the executor can't evaluate `COUNT(*)` in the HAVING
15//! clause - it needs a column reference.
16//!
17//! # Solution
18//!
19//! The transformer rewrites to:
20//! ```sql
21//! SELECT region, COUNT(*) as __agg_1 FROM sales GROUP BY region HAVING __agg_1 > 5
22//! ```
23//!
24//! # Algorithm
25//!
26//! 1. Find all aggregate functions in SELECT clause
27//! 2. For each aggregate without an explicit alias, generate one (__agg_N)
28//! 3. Scan HAVING clause for matching aggregate expressions
29//! 4. Replace aggregate expressions with column references to the aliases
30
31use crate::query_plan::pipeline::ASTTransformer;
32use crate::sql::parser::ast::{ColumnRef, QuoteStyle, SelectItem, SelectStatement, SqlExpression};
33use anyhow::Result;
34use std::collections::HashMap;
35use tracing::debug;
36
/// Transformer that adds aliases to aggregates and rewrites HAVING clauses.
///
/// The only state it carries is the counter used to mint `__agg_N` aliases.
#[derive(Debug)]
pub struct HavingAliasTransformer {
    /// Number of aliases generated so far; the next generated alias is
    /// `__agg_{alias_counter + 1}`.
    alias_counter: usize,
}
42
43impl HavingAliasTransformer {
44    pub fn new() -> Self {
45        Self { alias_counter: 0 }
46    }
47
48    /// Check if an expression is an aggregate function
49    fn is_aggregate_function(expr: &SqlExpression) -> bool {
50        matches!(
51            expr,
52            SqlExpression::FunctionCall { name, .. }
53                if matches!(
54                    name.to_uppercase().as_str(),
55                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
56                )
57        )
58    }
59
60    /// Generate a unique alias name
61    fn generate_alias(&mut self) -> String {
62        self.alias_counter += 1;
63        format!("__agg_{}", self.alias_counter)
64    }
65
66    /// Normalize an aggregate expression to a canonical form for comparison
67    fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
68        match expr {
69            SqlExpression::FunctionCall { name, args, .. } => {
70                let args_str = args
71                    .iter()
72                    .map(|arg| match arg {
73                        SqlExpression::Column(col_ref) => {
74                            format!("{}", col_ref.name)
75                        }
76                        SqlExpression::StringLiteral(s) => format!("'{}'", s),
77                        SqlExpression::NumberLiteral(n) => n.clone(),
78                        _ => format!("{:?}", arg), // Fallback for complex args
79                    })
80                    .collect::<Vec<_>>()
81                    .join(",");
82                format!("{}({})", name.to_uppercase(), args_str)
83            }
84            _ => format!("{:?}", expr),
85        }
86    }
87
88    /// Extract aggregate functions from SELECT clause and ensure they have aliases
89    fn ensure_aggregate_aliases(
90        &mut self,
91        select_items: &mut Vec<SelectItem>,
92    ) -> HashMap<String, String> {
93        let mut aggregate_map = HashMap::new();
94
95        for item in select_items.iter_mut() {
96            if let SelectItem::Expression { expr, alias, .. } = item {
97                if Self::is_aggregate_function(expr) {
98                    // Generate alias if none exists
99                    if alias.is_empty() {
100                        *alias = self.generate_alias();
101                        debug!(
102                            "Generated alias '{}' for aggregate: {}",
103                            alias,
104                            Self::normalize_aggregate_expr(expr)
105                        );
106                    }
107
108                    // Map normalized expression to alias
109                    let normalized = Self::normalize_aggregate_expr(expr);
110                    aggregate_map.insert(normalized, alias.clone());
111                }
112            }
113        }
114
115        aggregate_map
116    }
117
118    /// Rewrite a HAVING expression to use aliases instead of aggregates
119    fn rewrite_having_expression(
120        expr: &SqlExpression,
121        aggregate_map: &HashMap<String, String>,
122    ) -> SqlExpression {
123        match expr {
124            SqlExpression::FunctionCall { .. } if Self::is_aggregate_function(expr) => {
125                let normalized = Self::normalize_aggregate_expr(expr);
126                if let Some(alias) = aggregate_map.get(&normalized) {
127                    debug!(
128                        "Rewriting aggregate {} to column reference {}",
129                        normalized, alias
130                    );
131                    SqlExpression::Column(ColumnRef {
132                        name: alias.clone(),
133                        quote_style: QuoteStyle::None,
134                        table_prefix: None,
135                    })
136                } else {
137                    // Aggregate not found in SELECT - leave as is (will fail later with clear error)
138                    expr.clone()
139                }
140            }
141            SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
142                left: Box::new(Self::rewrite_having_expression(left, aggregate_map)),
143                op: op.clone(),
144                right: Box::new(Self::rewrite_having_expression(right, aggregate_map)),
145            },
146            // For other expressions, return as-is
147            _ => expr.clone(),
148        }
149    }
150}
151
152impl Default for HavingAliasTransformer {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158impl ASTTransformer for HavingAliasTransformer {
159    fn name(&self) -> &str {
160        "HavingAliasTransformer"
161    }
162
163    fn description(&self) -> &str {
164        "Adds aliases to aggregate functions and rewrites HAVING clauses to use them"
165    }
166
167    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
168        // Only process if there's a HAVING clause
169        if stmt.having.is_none() {
170            return Ok(stmt);
171        }
172
173        // Step 1: Ensure all aggregates in SELECT have aliases and build mapping
174        let aggregate_map = self.ensure_aggregate_aliases(&mut stmt.select_items);
175
176        if aggregate_map.is_empty() {
177            // No aggregates found, nothing to do
178            return Ok(stmt);
179        }
180
181        // Step 2: Rewrite HAVING clause to use aliases
182        if let Some(having_expr) = stmt.having.take() {
183            let rewritten = Self::rewrite_having_expression(&having_expr, &aggregate_map);
184
185            // Only set if something changed
186            if format!("{:?}", having_expr) != format!("{:?}", rewritten) {
187                debug!(
188                    "Rewrote HAVING clause with {} aggregate alias(es)",
189                    aggregate_map.len()
190                );
191                stmt.having = Some(rewritten);
192            } else {
193                stmt.having = Some(having_expr);
194            }
195        }
196
197        Ok(stmt)
198    }
199
200    fn begin(&mut self) -> Result<()> {
201        // Reset counter for each query
202        self.alias_counter = 0;
203        Ok(())
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column expression with no table prefix.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call expression.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_is_aggregate_function() {
        // Known aggregates are recognised...
        let count_expr = call("COUNT", vec![column("*")]);
        assert!(HavingAliasTransformer::is_aggregate_function(&count_expr));

        let sum_expr = call("SUM", vec![column("amount")]);
        assert!(HavingAliasTransformer::is_aggregate_function(&sum_expr));

        // ...while ordinary scalar functions are not.
        let non_agg = call("UPPER", vec![]);
        assert!(!HavingAliasTransformer::is_aggregate_function(&non_agg));
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        // A lower-case function name is upper-cased in the canonical form.
        let count_star = call("count", vec![column("*")]);
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&count_star),
            "COUNT(*)"
        );

        let sum_amount = call("SUM", vec![column("amount")]);
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&sum_amount),
            "SUM(amount)"
        );
    }

    #[test]
    fn test_generate_alias() {
        let mut transformer = HavingAliasTransformer::new();
        assert_eq!(transformer.generate_alias(), "__agg_1");
        assert_eq!(transformer.generate_alias(), "__agg_2");
        assert_eq!(transformer.generate_alias(), "__agg_3");
    }

    #[test]
    fn test_transform_with_no_having() {
        let mut transformer = HavingAliasTransformer::new();
        let stmt = SelectStatement {
            having: None,
            ..Default::default()
        };

        assert!(transformer.transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_adds_alias_and_rewrites_having() {
        let mut transformer = HavingAliasTransformer::new();
        let count_expr = call("COUNT", vec![column("*")]);

        // SELECT COUNT(*) ... HAVING COUNT(*) > 5, with no alias on the aggregate.
        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: count_expr.clone(),
                alias: String::new(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(count_expr.clone()),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // The SELECT item must now carry the generated alias.
        match &result.select_items[0] {
            SelectItem::Expression { alias, .. } => assert_eq!(alias, "__agg_1"),
            _ => panic!("Expected Expression select item"),
        }

        // The left side of the HAVING comparison must reference that alias.
        let left = match &result.having {
            Some(SqlExpression::BinaryOp { left, .. }) => left,
            _ => panic!("Expected BinaryOp in HAVING"),
        };
        match left.as_ref() {
            SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, "__agg_1"),
            _ => panic!("Expected column reference in HAVING, got: {:?}", left),
        }
    }
}