Skip to main content

sql_cli/query_plan/
having_alias_transformer.rs

1//! HAVING clause auto-aliasing transformer
2//!
3//! This transformer automatically adds aliases to aggregate functions in SELECT
4//! clauses and rewrites HAVING clauses to use those aliases instead of the
5//! aggregate function expressions.
6//!
7//! # Problem
8//!
9//! Users often write queries like:
10//! ```sql
11//! SELECT region, COUNT(*) FROM sales GROUP BY region HAVING COUNT(*) > 5
12//! ```
13//!
14//! This fails because the executor can't evaluate `COUNT(*)` in the HAVING
15//! clause - it needs a column reference.
16//!
17//! # Solution
18//!
19//! The transformer rewrites to:
20//! ```sql
21//! SELECT region, COUNT(*) as __agg_1 FROM sales GROUP BY region HAVING __agg_1 > 5
22//! ```
23//!
24//! # Algorithm
25//!
//! 1. Find all aggregate functions in the SELECT clause
//! 2. For each aggregate without an explicit alias, generate one (`__agg_N`)
//! 3. Promote aggregates that appear only in HAVING into extra SELECT items
//!    aliased `__hidden_agg_N`
//! 4. Scan the HAVING clause for matching aggregate expressions
//! 5. Replace aggregate expressions with column references to the aliases
30
31use crate::query_plan::pipeline::ASTTransformer;
32use crate::sql::parser::ast::{ColumnRef, QuoteStyle, SelectItem, SelectStatement, SqlExpression};
33use anyhow::Result;
34use std::collections::HashMap;
35use tracing::debug;
36
/// Prefix used for the aliases of aggregates promoted from HAVING into SELECT.
///
/// The transformer appends a running number to this prefix (for example
/// `__hidden_agg_1`) when it adds a SELECT item for an aggregate that only
/// appears in the HAVING clause.
///
/// Columns with this prefix are meant to be hidden from the final output.
/// NOTE(review): the hiding itself is not implemented in this file — confirm
/// which downstream component filters on this prefix.
pub const HIDDEN_AGG_PREFIX: &str = "__hidden_agg_";
40
/// Transformer that adds aliases to aggregates and rewrites HAVING clauses.
///
/// The two counters are independent: one numbers the visible `__agg_N`
/// aliases given to un-aliased SELECT aggregates, the other numbers the
/// hidden aliases given to aggregates that only occur in HAVING.
#[derive(Debug)]
pub struct HavingAliasTransformer {
    /// Number of `__agg_N` aliases generated so far; the next alias uses this value + 1.
    alias_counter: usize,
    /// Number of hidden (HAVING-promoted) aliases generated so far.
    hidden_counter: usize,
}
48
49impl HavingAliasTransformer {
50    pub fn new() -> Self {
51        Self {
52            alias_counter: 0,
53            hidden_counter: 0,
54        }
55    }
56
57    /// Check if an expression is an aggregate function
58    fn is_aggregate_function(expr: &SqlExpression) -> bool {
59        matches!(
60            expr,
61            SqlExpression::FunctionCall { name, .. }
62                if matches!(
63                    name.to_uppercase().as_str(),
64                    "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
65                )
66        )
67    }
68
69    /// Generate a unique alias name
70    fn generate_alias(&mut self) -> String {
71        self.alias_counter += 1;
72        format!("__agg_{}", self.alias_counter)
73    }
74
75    /// Normalize an aggregate expression to a canonical form for comparison
76    fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
77        match expr {
78            SqlExpression::FunctionCall { name, args, .. } => {
79                let args_str = args
80                    .iter()
81                    .map(|arg| match arg {
82                        SqlExpression::Column(col_ref) => {
83                            format!("{}", col_ref.name)
84                        }
85                        SqlExpression::StringLiteral(s) => format!("'{}'", s),
86                        SqlExpression::NumberLiteral(n) => n.clone(),
87                        _ => format!("{:?}", arg), // Fallback for complex args
88                    })
89                    .collect::<Vec<_>>()
90                    .join(",");
91                format!("{}({})", name.to_uppercase(), args_str)
92            }
93            _ => format!("{:?}", expr),
94        }
95    }
96
97    /// Extract aggregate functions from SELECT clause and ensure they have aliases
98    fn ensure_aggregate_aliases(
99        &mut self,
100        select_items: &mut Vec<SelectItem>,
101    ) -> HashMap<String, String> {
102        let mut aggregate_map = HashMap::new();
103
104        for item in select_items.iter_mut() {
105            if let SelectItem::Expression { expr, alias, .. } = item {
106                if Self::is_aggregate_function(expr) {
107                    // Generate alias if none exists
108                    if alias.is_empty() {
109                        *alias = self.generate_alias();
110                        debug!(
111                            "Generated alias '{}' for aggregate: {}",
112                            alias,
113                            Self::normalize_aggregate_expr(expr)
114                        );
115                    }
116
117                    // Map normalized expression to alias
118                    let normalized = Self::normalize_aggregate_expr(expr);
119                    aggregate_map.insert(normalized, alias.clone());
120                }
121            }
122        }
123
124        aggregate_map
125    }
126
127    /// Generate a unique hidden alias name for aggregates promoted from HAVING
128    fn generate_hidden_alias(&mut self) -> String {
129        self.hidden_counter += 1;
130        format!("{}{}", HIDDEN_AGG_PREFIX, self.hidden_counter)
131    }
132
133    /// Collect all aggregate function calls from a HAVING expression
134    fn collect_aggregates_in_having(expr: &SqlExpression, found: &mut Vec<SqlExpression>) {
135        match expr {
136            SqlExpression::FunctionCall { args, .. } if Self::is_aggregate_function(expr) => {
137                found.push(expr.clone());
138                // Don't recurse into aggregate args — nested aggregates aren't supported anyway
139                let _ = args;
140            }
141            SqlExpression::BinaryOp { left, right, .. } => {
142                Self::collect_aggregates_in_having(left, found);
143                Self::collect_aggregates_in_having(right, found);
144            }
145            SqlExpression::Not { expr } => {
146                Self::collect_aggregates_in_having(expr, found);
147            }
148            SqlExpression::FunctionCall { args, .. } => {
149                // Non-aggregate function — check its arguments for aggregates
150                for arg in args {
151                    Self::collect_aggregates_in_having(arg, found);
152                }
153            }
154            _ => {}
155        }
156    }
157
158    /// Promote aggregates in HAVING that aren't already in SELECT into hidden
159    /// SELECT items. Returns updated aggregate_map with new entries.
160    fn promote_having_aggregates(
161        &mut self,
162        having_expr: &SqlExpression,
163        select_items: &mut Vec<SelectItem>,
164        aggregate_map: &mut HashMap<String, String>,
165    ) {
166        let mut having_aggs = Vec::new();
167        Self::collect_aggregates_in_having(having_expr, &mut having_aggs);
168
169        for agg in having_aggs {
170            let normalized = Self::normalize_aggregate_expr(&agg);
171            if aggregate_map.contains_key(&normalized) {
172                continue; // Already in SELECT
173            }
174
175            let hidden_alias = self.generate_hidden_alias();
176            debug!(
177                "Promoting HAVING aggregate {} as hidden SELECT item '{}'",
178                normalized, hidden_alias
179            );
180
181            select_items.push(SelectItem::Expression {
182                expr: agg,
183                alias: hidden_alias.clone(),
184                leading_comments: Vec::new(),
185                trailing_comment: None,
186            });
187
188            aggregate_map.insert(normalized, hidden_alias);
189        }
190    }
191
192    /// Rewrite a HAVING expression to use aliases instead of aggregates
193    fn rewrite_having_expression(
194        expr: &SqlExpression,
195        aggregate_map: &HashMap<String, String>,
196    ) -> SqlExpression {
197        match expr {
198            SqlExpression::FunctionCall { .. } if Self::is_aggregate_function(expr) => {
199                let normalized = Self::normalize_aggregate_expr(expr);
200                if let Some(alias) = aggregate_map.get(&normalized) {
201                    debug!(
202                        "Rewriting aggregate {} to column reference {}",
203                        normalized, alias
204                    );
205                    SqlExpression::Column(ColumnRef {
206                        name: alias.clone(),
207                        quote_style: QuoteStyle::None,
208                        table_prefix: None,
209                    })
210                } else {
211                    // Aggregate not found in SELECT - leave as is (will fail later with clear error)
212                    expr.clone()
213                }
214            }
215            SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
216                left: Box::new(Self::rewrite_having_expression(left, aggregate_map)),
217                op: op.clone(),
218                right: Box::new(Self::rewrite_having_expression(right, aggregate_map)),
219            },
220            SqlExpression::Not { expr } => SqlExpression::Not {
221                expr: Box::new(Self::rewrite_having_expression(expr, aggregate_map)),
222            },
223            SqlExpression::FunctionCall {
224                name,
225                args,
226                distinct,
227            } => {
228                // Non-aggregate function — recurse into args
229                SqlExpression::FunctionCall {
230                    name: name.clone(),
231                    args: args
232                        .iter()
233                        .map(|a| Self::rewrite_having_expression(a, aggregate_map))
234                        .collect(),
235                    distinct: *distinct,
236                }
237            }
238            // For other expressions, return as-is
239            _ => expr.clone(),
240        }
241    }
242}
243
244impl Default for HavingAliasTransformer {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250impl ASTTransformer for HavingAliasTransformer {
251    fn name(&self) -> &str {
252        "HavingAliasTransformer"
253    }
254
255    fn description(&self) -> &str {
256        "Adds aliases to aggregate functions and rewrites HAVING clauses to use them"
257    }
258
259    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
260        // Only process if there's a HAVING clause
261        if stmt.having.is_none() {
262            return Ok(stmt);
263        }
264
265        // Step 1: Ensure all aggregates in SELECT have aliases and build mapping
266        let mut aggregate_map = self.ensure_aggregate_aliases(&mut stmt.select_items);
267
268        // Step 1b: Promote any HAVING-only aggregates into SELECT with hidden aliases
269        // This allows HAVING to reference aggregates not already in SELECT
270        if let Some(ref having_expr) = stmt.having {
271            self.promote_having_aggregates(having_expr, &mut stmt.select_items, &mut aggregate_map);
272        }
273
274        if aggregate_map.is_empty() {
275            // No aggregates found, nothing to do
276            return Ok(stmt);
277        }
278
279        // Step 2: Rewrite HAVING clause to use aliases
280        if let Some(having_expr) = stmt.having.take() {
281            let rewritten = Self::rewrite_having_expression(&having_expr, &aggregate_map);
282
283            // Only set if something changed
284            if format!("{:?}", having_expr) != format!("{:?}", rewritten) {
285                debug!(
286                    "Rewrote HAVING clause with {} aggregate alias(es)",
287                    aggregate_map.len()
288                );
289                stmt.having = Some(rewritten);
290            } else {
291                stmt.having = Some(having_expr);
292            }
293        }
294
295        Ok(stmt)
296    }
297
298    fn begin(&mut self) -> Result<()> {
299        // Reset counters for each query
300        self.alias_counter = 0;
301        self.hidden_counter = 0;
302        Ok(())
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference expression without a table prefix.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call expression.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    /// COUNT and SUM are recognised as aggregates; UPPER is not.
    #[test]
    fn test_is_aggregate_function() {
        let count_expr = call("COUNT", vec![column("*")]);
        assert!(HavingAliasTransformer::is_aggregate_function(&count_expr));

        let sum_expr = call("SUM", vec![column("amount")]);
        assert!(HavingAliasTransformer::is_aggregate_function(&sum_expr));

        let non_agg = call("UPPER", vec![]);
        assert!(!HavingAliasTransformer::is_aggregate_function(&non_agg));
    }

    /// Normalization upper-cases the function name and renders column arguments by name.
    #[test]
    fn test_normalize_aggregate_expr() {
        let count_star = call("count", vec![column("*")]);
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&count_star),
            "COUNT(*)"
        );

        let sum_amount = call("SUM", vec![column("amount")]);
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&sum_amount),
            "SUM(amount)"
        );
    }

    /// Generated aliases are numbered consecutively starting at 1.
    #[test]
    fn test_generate_alias() {
        let mut transformer = HavingAliasTransformer::new();
        for expected in ["__agg_1", "__agg_2", "__agg_3"] {
            assert_eq!(transformer.generate_alias(), expected);
        }
    }

    /// A statement without HAVING is accepted by `transform`.
    #[test]
    fn test_transform_with_no_having() {
        let mut transformer = HavingAliasTransformer::new();
        let stmt = SelectStatement {
            having: None,
            ..Default::default()
        };

        assert!(transformer.transform(stmt).is_ok());
    }

    /// An un-aliased `COUNT(*)` in SELECT gets `__agg_1`, and HAVING is rewritten to use it.
    #[test]
    fn test_transform_adds_alias_and_rewrites_having() {
        let mut transformer = HavingAliasTransformer::new();
        let count_expr = call("COUNT", vec![column("*")]);

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: count_expr.clone(),
                alias: String::new(), // No alias initially
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(count_expr.clone()),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            }),
            ..Default::default()
        };

        let transformed = transformer.transform(stmt).unwrap();

        // The SELECT item must now carry the generated alias.
        match &transformed.select_items[0] {
            SelectItem::Expression { alias, .. } => assert_eq!(alias, "__agg_1"),
            _ => panic!("Expected Expression select item"),
        }

        // The left-hand side of HAVING must be a column reference to that alias.
        match &transformed.having {
            Some(SqlExpression::BinaryOp { left, .. }) => match left.as_ref() {
                SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, "__agg_1"),
                _ => panic!("Expected column reference in HAVING, got: {:?}", left),
            },
            _ => panic!("Expected BinaryOp in HAVING"),
        }
    }
}
443}