1use crate::query_plan::pipeline::ASTTransformer;
32use crate::sql::parser::ast::{
33 CTEType, ColumnRef, QuoteStyle, SelectItem, SelectStatement, SqlExpression, TableSource,
34};
35use anyhow::Result;
36use std::collections::HashMap;
37use tracing::debug;
38
/// Alias prefix for aggregates that occur in a HAVING clause but not in the
/// SELECT list.
///
/// `HavingAliasTransformer` appends such aggregates to the SELECT list under the
/// alias `__hidden_agg_<N>`, so that HAVING can reference them as plain columns.
///
/// NOTE(review): this is `pub`, presumably so later stages can recognize these
/// helper columns and keep them out of the visible result — confirm against callers.
pub const HIDDEN_AGG_PREFIX: &str = "__hidden_agg_";
42
/// AST pass that makes aggregate calls in HAVING refer to SELECT-list aliases.
///
/// Each unaliased aggregate in the SELECT list receives a generated `__agg_<N>`
/// alias. Aggregates that appear only in HAVING are appended to the SELECT list
/// under a hidden `__hidden_agg_<N>` alias. Every matching aggregate call in
/// HAVING is then replaced by a column reference to the corresponding alias.
#[derive(Debug)]
pub struct HavingAliasTransformer {
    /// Number of visible `__agg_<N>` aliases generated so far.
    alias_counter: usize,
    /// Number of hidden `__hidden_agg_<N>` aliases generated so far.
    hidden_counter: usize,
}
50
51impl HavingAliasTransformer {
52 pub fn new() -> Self {
53 Self {
54 alias_counter: 0,
55 hidden_counter: 0,
56 }
57 }
58
59 fn is_aggregate_function(expr: &SqlExpression) -> bool {
61 matches!(
62 expr,
63 SqlExpression::FunctionCall { name, .. }
64 if matches!(
65 name.to_uppercase().as_str(),
66 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
67 )
68 )
69 }
70
71 fn generate_alias(&mut self) -> String {
73 self.alias_counter += 1;
74 format!("__agg_{}", self.alias_counter)
75 }
76
77 fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
79 match expr {
80 SqlExpression::FunctionCall { name, args, .. } => {
81 let args_str = args
82 .iter()
83 .map(|arg| match arg {
84 SqlExpression::Column(col_ref) => {
85 format!("{}", col_ref.name)
86 }
87 SqlExpression::StringLiteral(s) => format!("'{}'", s),
88 SqlExpression::NumberLiteral(n) => n.clone(),
89 _ => format!("{:?}", arg), })
91 .collect::<Vec<_>>()
92 .join(",");
93 format!("{}({})", name.to_uppercase(), args_str)
94 }
95 _ => format!("{:?}", expr),
96 }
97 }
98
99 fn ensure_aggregate_aliases(
101 &mut self,
102 select_items: &mut Vec<SelectItem>,
103 ) -> HashMap<String, String> {
104 let mut aggregate_map = HashMap::new();
105
106 for item in select_items.iter_mut() {
107 if let SelectItem::Expression { expr, alias, .. } = item {
108 if Self::is_aggregate_function(expr) {
109 if alias.is_empty() {
111 *alias = self.generate_alias();
112 debug!(
113 "Generated alias '{}' for aggregate: {}",
114 alias,
115 Self::normalize_aggregate_expr(expr)
116 );
117 }
118
119 let normalized = Self::normalize_aggregate_expr(expr);
121 aggregate_map.insert(normalized, alias.clone());
122 }
123 }
124 }
125
126 aggregate_map
127 }
128
129 fn generate_hidden_alias(&mut self) -> String {
131 self.hidden_counter += 1;
132 format!("{}{}", HIDDEN_AGG_PREFIX, self.hidden_counter)
133 }
134
135 fn collect_aggregates_in_having(expr: &SqlExpression, found: &mut Vec<SqlExpression>) {
137 match expr {
138 SqlExpression::FunctionCall { args, .. } if Self::is_aggregate_function(expr) => {
139 found.push(expr.clone());
140 let _ = args;
142 }
143 SqlExpression::BinaryOp { left, right, .. } => {
144 Self::collect_aggregates_in_having(left, found);
145 Self::collect_aggregates_in_having(right, found);
146 }
147 SqlExpression::Not { expr } => {
148 Self::collect_aggregates_in_having(expr, found);
149 }
150 SqlExpression::FunctionCall { args, .. } => {
151 for arg in args {
153 Self::collect_aggregates_in_having(arg, found);
154 }
155 }
156 _ => {}
157 }
158 }
159
160 fn promote_having_aggregates(
163 &mut self,
164 having_expr: &SqlExpression,
165 select_items: &mut Vec<SelectItem>,
166 aggregate_map: &mut HashMap<String, String>,
167 ) {
168 let mut having_aggs = Vec::new();
169 Self::collect_aggregates_in_having(having_expr, &mut having_aggs);
170
171 for agg in having_aggs {
172 let normalized = Self::normalize_aggregate_expr(&agg);
173 if aggregate_map.contains_key(&normalized) {
174 continue; }
176
177 let hidden_alias = self.generate_hidden_alias();
178 debug!(
179 "Promoting HAVING aggregate {} as hidden SELECT item '{}'",
180 normalized, hidden_alias
181 );
182
183 select_items.push(SelectItem::Expression {
184 expr: agg,
185 alias: hidden_alias.clone(),
186 leading_comments: Vec::new(),
187 trailing_comment: None,
188 });
189
190 aggregate_map.insert(normalized, hidden_alias);
191 }
192 }
193
194 fn rewrite_having_expression(
196 expr: &SqlExpression,
197 aggregate_map: &HashMap<String, String>,
198 ) -> SqlExpression {
199 match expr {
200 SqlExpression::FunctionCall { .. } if Self::is_aggregate_function(expr) => {
201 let normalized = Self::normalize_aggregate_expr(expr);
202 if let Some(alias) = aggregate_map.get(&normalized) {
203 debug!(
204 "Rewriting aggregate {} to column reference {}",
205 normalized, alias
206 );
207 SqlExpression::Column(ColumnRef {
208 name: alias.clone(),
209 quote_style: QuoteStyle::None,
210 table_prefix: None,
211 })
212 } else {
213 expr.clone()
215 }
216 }
217 SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
218 left: Box::new(Self::rewrite_having_expression(left, aggregate_map)),
219 op: op.clone(),
220 right: Box::new(Self::rewrite_having_expression(right, aggregate_map)),
221 },
222 SqlExpression::Not { expr } => SqlExpression::Not {
223 expr: Box::new(Self::rewrite_having_expression(expr, aggregate_map)),
224 },
225 SqlExpression::FunctionCall {
226 name,
227 args,
228 distinct,
229 } => {
230 SqlExpression::FunctionCall {
232 name: name.clone(),
233 args: args
234 .iter()
235 .map(|a| Self::rewrite_having_expression(a, aggregate_map))
236 .collect(),
237 distinct: *distinct,
238 }
239 }
240 _ => expr.clone(),
242 }
243 }
244
245 #[allow(deprecated)]
249 fn transform_statement(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
250 for cte in stmt.ctes.iter_mut() {
252 if let CTEType::Standard(ref mut inner) = cte.cte_type {
253 let taken = std::mem::take(inner);
254 *inner = self.transform_statement(taken)?;
255 }
256 }
257
258 if let Some(TableSource::DerivedTable { query, .. }) = stmt.from_source.as_mut() {
260 let taken = std::mem::take(query.as_mut());
261 **query = self.transform_statement(taken)?;
262 }
263
264 if let Some(subq) = stmt.from_subquery.as_mut() {
266 let taken = std::mem::take(subq.as_mut());
267 **subq = self.transform_statement(taken)?;
268 }
269
270 for (_op, rhs) in stmt.set_operations.iter_mut() {
272 let taken = std::mem::take(rhs.as_mut());
273 **rhs = self.transform_statement(taken)?;
274 }
275
276 self.apply_having_rewrite(&mut stmt);
278
279 Ok(stmt)
280 }
281
282 fn apply_having_rewrite(&mut self, stmt: &mut SelectStatement) {
285 if stmt.having.is_none() {
287 return;
288 }
289
290 let mut aggregate_map = self.ensure_aggregate_aliases(&mut stmt.select_items);
292
293 if let Some(ref having_expr) = stmt.having {
295 self.promote_having_aggregates(having_expr, &mut stmt.select_items, &mut aggregate_map);
296 }
297
298 if aggregate_map.is_empty() {
299 return;
300 }
301
302 if let Some(having_expr) = stmt.having.take() {
304 let rewritten = Self::rewrite_having_expression(&having_expr, &aggregate_map);
305 if format!("{:?}", having_expr) != format!("{:?}", rewritten) {
306 debug!(
307 "Rewrote HAVING clause with {} aggregate alias(es)",
308 aggregate_map.len()
309 );
310 stmt.having = Some(rewritten);
311 } else {
312 stmt.having = Some(having_expr);
313 }
314 }
315 }
316}
317
318impl Default for HavingAliasTransformer {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
324impl ASTTransformer for HavingAliasTransformer {
325 fn name(&self) -> &str {
326 "HavingAliasTransformer"
327 }
328
329 fn description(&self) -> &str {
330 "Adds aliases to aggregate functions and rewrites HAVING clauses to use them"
331 }
332
333 fn transform(&mut self, stmt: SelectStatement) -> Result<SelectStatement> {
334 self.transform_statement(stmt)
335 }
336
337 fn begin(&mut self) -> Result<()> {
338 self.alias_counter = 0;
340 self.hidden_counter = 0;
341 Ok(())
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted, unqualified column-reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call with the given arguments.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_is_aggregate_function() {
        assert!(HavingAliasTransformer::is_aggregate_function(&call(
            "COUNT",
            vec![column("*")]
        )));
        assert!(HavingAliasTransformer::is_aggregate_function(&call(
            "SUM",
            vec![column("amount")]
        )));
        assert!(!HavingAliasTransformer::is_aggregate_function(&call(
            "UPPER",
            Vec::new()
        )));
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let cases = [
            (call("count", vec![column("*")]), "COUNT(*)"),
            (call("SUM", vec![column("amount")]), "SUM(amount)"),
        ];
        for (expr, expected) in &cases {
            assert_eq!(
                HavingAliasTransformer::normalize_aggregate_expr(expr),
                *expected
            );
        }
    }

    #[test]
    fn test_generate_alias() {
        let mut transformer = HavingAliasTransformer::new();
        for expected in ["__agg_1", "__agg_2", "__agg_3"] {
            assert_eq!(transformer.generate_alias(), expected);
        }
    }

    #[test]
    fn test_transform_with_no_having() {
        let stmt = SelectStatement {
            having: None,
            ..Default::default()
        };
        let mut transformer = HavingAliasTransformer::new();
        assert!(transformer.transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_adds_alias_and_rewrites_having() {
        let count_star = call("COUNT", vec![column("*")]);

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: count_star.clone(),
                alias: String::new(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(count_star),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            }),
            ..Default::default()
        };

        let mut transformer = HavingAliasTransformer::new();
        let result = transformer.transform(stmt).unwrap();

        match result.select_items.first() {
            Some(SelectItem::Expression { alias, .. }) => assert_eq!(alias, "__agg_1"),
            _ => panic!("Expected Expression select item"),
        }

        match &result.having {
            Some(SqlExpression::BinaryOp { left, .. }) => match left.as_ref() {
                SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, "__agg_1"),
                other => panic!("Expected column reference in HAVING, got: {:?}", other),
            },
            _ => panic!("Expected BinaryOp in HAVING"),
        }
    }
}
482}