Skip to main content

sql_cli/query_plan/
having_alias_transformer.rs

1//! HAVING clause auto-aliasing transformer
2//!
3//! This transformer automatically adds aliases to aggregate functions in SELECT
4//! clauses and rewrites HAVING clauses to use those aliases instead of the
5//! aggregate function expressions.
6//!
7//! # Problem
8//!
9//! Users often write queries like:
10//! ```sql
11//! SELECT region, COUNT(*) FROM sales GROUP BY region HAVING COUNT(*) > 5
12//! ```
13//!
14//! This fails because the executor can't evaluate `COUNT(*)` in the HAVING
15//! clause - it needs a column reference.
16//!
17//! # Solution
18//!
19//! The transformer rewrites to:
20//! ```sql
21//! SELECT region, COUNT(*) as __agg_1 FROM sales GROUP BY region HAVING __agg_1 > 5
22//! ```
23//!
24//! # Algorithm
25//!
26//! 1. Find all aggregate functions in SELECT clause
27//! 2. For each aggregate without an explicit alias, generate one (`__agg_N`)
28//! 3. Scan HAVING clause for aggregate expressions; add any that are missing from SELECT as hidden SELECT items (`__hidden_agg_N`)
29//! 4. Replace aggregate expressions in HAVING with column references to the aliases
30
31use crate::query_plan::pipeline::ASTTransformer;
32use crate::sql::parser::ast::{
33    CTEType, ColumnRef, QuoteStyle, SelectItem, SelectStatement, SqlExpression, TableSource,
34};
35use anyhow::Result;
36use std::collections::HashMap;
37use tracing::debug;
38
/// Prefix used for aggregates promoted from HAVING into SELECT.
///
/// Aliases are built as this prefix followed by a 1-based counter
/// (e.g. `__hidden_agg_1`). Columns with this prefix are hidden from the
/// final output; that filtering is not performed in this file.
pub const HIDDEN_AGG_PREFIX: &str = "__hidden_agg_";

/// Transformer that adds aliases to aggregates and rewrites HAVING clauses
/// to reference those aliases instead of the aggregate expressions.
#[derive(Debug)]
pub struct HavingAliasTransformer {
    /// Number of `__agg_N` aliases generated so far; the next alias uses
    /// this value plus one.
    alias_counter: usize,
    /// Number of hidden (HAVING-promoted) aliases generated so far; the next
    /// hidden alias uses this value plus one.
    hidden_counter: usize,
}
50
51impl HavingAliasTransformer {
52    pub fn new() -> Self {
53        Self {
54            alias_counter: 0,
55            hidden_counter: 0,
56        }
57    }
58
59    /// Check if an expression is an aggregate function
60    fn is_aggregate_function(expr: &SqlExpression) -> bool {
61        matches!(
62            expr,
63            SqlExpression::FunctionCall { name, .. }
64                if matches!(
65                    name.to_uppercase().as_str(),
66                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
67                )
68        )
69    }
70
71    /// Generate a unique alias name
72    fn generate_alias(&mut self) -> String {
73        self.alias_counter += 1;
74        format!("__agg_{}", self.alias_counter)
75    }
76
77    /// Normalize an aggregate expression to a canonical form for comparison
78    fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
79        match expr {
80            SqlExpression::FunctionCall { name, args, .. } => {
81                let args_str = args
82                    .iter()
83                    .map(|arg| match arg {
84                        SqlExpression::Column(col_ref) => {
85                            format!("{}", col_ref.name)
86                        }
87                        SqlExpression::StringLiteral(s) => format!("'{}'", s),
88                        SqlExpression::NumberLiteral(n) => n.clone(),
89                        _ => format!("{:?}", arg), // Fallback for complex args
90                    })
91                    .collect::<Vec<_>>()
92                    .join(",");
93                format!("{}({})", name.to_uppercase(), args_str)
94            }
95            _ => format!("{:?}", expr),
96        }
97    }
98
99    /// Extract aggregate functions from SELECT clause and ensure they have aliases
100    fn ensure_aggregate_aliases(
101        &mut self,
102        select_items: &mut Vec<SelectItem>,
103    ) -> HashMap<String, String> {
104        let mut aggregate_map = HashMap::new();
105
106        for item in select_items.iter_mut() {
107            if let SelectItem::Expression { expr, alias, .. } = item {
108                if Self::is_aggregate_function(expr) {
109                    // Generate alias if none exists
110                    if alias.is_empty() {
111                        *alias = self.generate_alias();
112                        debug!(
113                            "Generated alias '{}' for aggregate: {}",
114                            alias,
115                            Self::normalize_aggregate_expr(expr)
116                        );
117                    }
118
119                    // Map normalized expression to alias
120                    let normalized = Self::normalize_aggregate_expr(expr);
121                    aggregate_map.insert(normalized, alias.clone());
122                }
123            }
124        }
125
126        aggregate_map
127    }
128
129    /// Generate a unique hidden alias name for aggregates promoted from HAVING
130    fn generate_hidden_alias(&mut self) -> String {
131        self.hidden_counter += 1;
132        format!("{}{}", HIDDEN_AGG_PREFIX, self.hidden_counter)
133    }
134
135    /// Collect all aggregate function calls from a HAVING expression
136    fn collect_aggregates_in_having(expr: &SqlExpression, found: &mut Vec<SqlExpression>) {
137        match expr {
138            SqlExpression::FunctionCall { args, .. } if Self::is_aggregate_function(expr) => {
139                found.push(expr.clone());
140                // Don't recurse into aggregate args — nested aggregates aren't supported anyway
141                let _ = args;
142            }
143            SqlExpression::BinaryOp { left, right, .. } => {
144                Self::collect_aggregates_in_having(left, found);
145                Self::collect_aggregates_in_having(right, found);
146            }
147            SqlExpression::Not { expr } => {
148                Self::collect_aggregates_in_having(expr, found);
149            }
150            SqlExpression::FunctionCall { args, .. } => {
151                // Non-aggregate function — check its arguments for aggregates
152                for arg in args {
153                    Self::collect_aggregates_in_having(arg, found);
154                }
155            }
156            _ => {}
157        }
158    }
159
160    /// Promote aggregates in HAVING that aren't already in SELECT into hidden
161    /// SELECT items. Returns updated aggregate_map with new entries.
162    fn promote_having_aggregates(
163        &mut self,
164        having_expr: &SqlExpression,
165        select_items: &mut Vec<SelectItem>,
166        aggregate_map: &mut HashMap<String, String>,
167    ) {
168        let mut having_aggs = Vec::new();
169        Self::collect_aggregates_in_having(having_expr, &mut having_aggs);
170
171        for agg in having_aggs {
172            let normalized = Self::normalize_aggregate_expr(&agg);
173            if aggregate_map.contains_key(&normalized) {
174                continue; // Already in SELECT
175            }
176
177            let hidden_alias = self.generate_hidden_alias();
178            debug!(
179                "Promoting HAVING aggregate {} as hidden SELECT item '{}'",
180                normalized, hidden_alias
181            );
182
183            select_items.push(SelectItem::Expression {
184                expr: agg,
185                alias: hidden_alias.clone(),
186                leading_comments: Vec::new(),
187                trailing_comment: None,
188            });
189
190            aggregate_map.insert(normalized, hidden_alias);
191        }
192    }
193
194    /// Rewrite a HAVING expression to use aliases instead of aggregates
195    fn rewrite_having_expression(
196        expr: &SqlExpression,
197        aggregate_map: &HashMap<String, String>,
198    ) -> SqlExpression {
199        match expr {
200            SqlExpression::FunctionCall { .. } if Self::is_aggregate_function(expr) => {
201                let normalized = Self::normalize_aggregate_expr(expr);
202                if let Some(alias) = aggregate_map.get(&normalized) {
203                    debug!(
204                        "Rewriting aggregate {} to column reference {}",
205                        normalized, alias
206                    );
207                    SqlExpression::Column(ColumnRef {
208                        name: alias.clone(),
209                        quote_style: QuoteStyle::None,
210                        table_prefix: None,
211                    })
212                } else {
213                    // Aggregate not found in SELECT - leave as is (will fail later with clear error)
214                    expr.clone()
215                }
216            }
217            SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
218                left: Box::new(Self::rewrite_having_expression(left, aggregate_map)),
219                op: op.clone(),
220                right: Box::new(Self::rewrite_having_expression(right, aggregate_map)),
221            },
222            SqlExpression::Not { expr } => SqlExpression::Not {
223                expr: Box::new(Self::rewrite_having_expression(expr, aggregate_map)),
224            },
225            SqlExpression::FunctionCall {
226                name,
227                args,
228                distinct,
229            } => {
230                // Non-aggregate function — recurse into args
231                SqlExpression::FunctionCall {
232                    name: name.clone(),
233                    args: args
234                        .iter()
235                        .map(|a| Self::rewrite_having_expression(a, aggregate_map))
236                        .collect(),
237                    distinct: *distinct,
238                }
239            }
240            // For other expressions, return as-is
241            _ => expr.clone(),
242        }
243    }
244
245    /// Transform a SelectStatement and recursively apply to nested statements
246    /// (CTEs, FROM subqueries, set operations). This ensures HAVING clauses in
247    /// any nested query — including subqueries in FROM — are properly rewritten.
248    #[allow(deprecated)]
249    fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
250        // Step A: recurse into CTEs
251        for cte in stmt.ctes.iter_mut() {
252            if let CTEType::Standard(ref mut inner) = cte.cte_type {
253                let taken = std::mem::take(inner);
254                *inner = self.transform_statement(taken)?;
255            }
256        }
257
258        // Step B: recurse into FROM source (DerivedTable)
259        if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
260            let taken = std::mem::take(query.as_mut());
261            **query = self.transform_statement(taken)?;
262        }
263
264        // Step B2: recurse into legacy from_subquery
265        if let Some(subq) = stmt.from_subquery.as_mut() {
266            let taken = std::mem::take(subq.as_mut());
267            **subq = self.transform_statement(taken)?;
268        }
269
270        // Step C: recurse into set operation right-hand sides
271        for (_op, rhs) in stmt.set_operations.iter_mut() {
272            let taken = std::mem::take(rhs.as_mut());
273            **rhs = self.transform_statement(taken)?;
274        }
275
276        // Step D: apply HAVING transformation at this level
277        self.apply_having_rewrite(&mut stmt);
278
279        Ok(stmt)
280    }
281
282    /// Apply HAVING aggregate promotion and rewriting to a single statement
283    /// (no recursion — the caller is responsible for that).
284    fn apply_having_rewrite(&mut self, stmt: &mut SelectStatement) {
285        // Only process if there's a HAVING clause
286        if stmt.having.is_none() {
287            return;
288        }
289
290        // Step 1: Ensure all aggregates in SELECT have aliases and build mapping
291        let mut aggregate_map = self.ensure_aggregate_aliases(&mut stmt.select_items);
292
293        // Step 1b: Promote any HAVING-only aggregates into SELECT with hidden aliases
294        if let Some(ref having_expr) = stmt.having {
295            self.promote_having_aggregates(having_expr, &mut stmt.select_items, &mut aggregate_map);
296        }
297
298        if aggregate_map.is_empty() {
299            return;
300        }
301
302        // Step 2: Rewrite HAVING clause to use aliases
303        if let Some(having_expr) = stmt.having.take() {
304            let rewritten = Self::rewrite_having_expression(&having_expr, &aggregate_map);
305            if format!("{:?}", having_expr) != format!("{:?}", rewritten) {
306                debug!(
307                    "Rewrote HAVING clause with {} aggregate alias(es)",
308                    aggregate_map.len()
309                );
310                stmt.having = Some(rewritten);
311            } else {
312                stmt.having = Some(having_expr);
313            }
314        }
315    }
316}
317
318impl Default for HavingAliasTransformer {
319    fn default() -> Self {
320        Self::new()
321    }
322}
323
324impl ASTTransformer for HavingAliasTransformer {
325    fn name(&self) -> &str {
326        "HavingAliasTransformer"
327    }
328
329    fn description(&self) -> &str {
330        "Adds aliases to aggregate functions and rewrites HAVING clauses to use them"
331    }
332
333    fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
334        self.transform_statement(stmt)
335    }
336
337    fn begin(&mut self) -> Result<()> {
338        // Reset counters for each query
339        self.alias_counter = 0;
340        self.hidden_counter = 0;
341        Ok(())
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference expression with no table prefix.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call expression.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_is_aggregate_function() {
        // Recognized aggregates.
        let count_expr = call("COUNT", vec![column("*")]);
        assert!(HavingAliasTransformer::is_aggregate_function(&count_expr));

        let sum_expr = call("SUM", vec![column("amount")]);
        assert!(HavingAliasTransformer::is_aggregate_function(&sum_expr));

        // A scalar function must not be classified as an aggregate.
        let non_agg = call("UPPER", vec![]);
        assert!(!HavingAliasTransformer::is_aggregate_function(&non_agg));
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        // Lower-case function names are upper-cased in the canonical form.
        let count_star = call("count", vec![column("*")]);
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&count_star),
            "COUNT(*)"
        );

        let sum_amount = call("SUM", vec![column("amount")]);
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&sum_amount),
            "SUM(amount)"
        );
    }

    #[test]
    fn test_generate_alias() {
        let mut transformer = HavingAliasTransformer::new();
        for expected in ["__agg_1", "__agg_2", "__agg_3"] {
            assert_eq!(transformer.generate_alias(), expected);
        }
    }

    #[test]
    fn test_transform_with_no_having() {
        let mut transformer = HavingAliasTransformer::new();
        let stmt = SelectStatement {
            having: None,
            ..Default::default()
        };

        assert!(transformer.transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_adds_alias_and_rewrites_having() {
        let mut transformer = HavingAliasTransformer::new();
        let count_expr = call("COUNT", vec![column("*")]);

        // SELECT COUNT(*) ... HAVING COUNT(*) > 5, with no alias on the aggregate.
        let select_item = SelectItem::Expression {
            expr: count_expr.clone(),
            alias: String::new(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        };
        let having = SqlExpression::BinaryOp {
            left: Box::new(count_expr.clone()),
            op: ">".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
        };
        let stmt = SelectStatement {
            select_items: vec![select_item],
            having: Some(having),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // The SELECT aggregate must have received the first generated alias.
        match &result.select_items[0] {
            SelectItem::Expression { alias, .. } => assert_eq!(alias, "__agg_1"),
            _ => panic!("Expected Expression select item"),
        }

        // The HAVING left operand must now be a column reference to that alias.
        match &result.having {
            Some(SqlExpression::BinaryOp { left, .. }) => match left.as_ref() {
                SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, "__agg_1"),
                _ => panic!("Expected column reference in HAVING, got: {:?}", left),
            },
            _ => panic!("Expected BinaryOp in HAVING"),
        }
    }
}