1use crate::query_plan::pipeline::ASTTransformer;
32use crate::sql::parser::ast::{ColumnRef, QuoteStyle, SelectItem, SelectStatement, SqlExpression};
33use anyhow::Result;
34use std::collections::HashMap;
35use tracing::debug;
36
/// Alias prefix given to aggregates that appear in HAVING but not in the SELECT
/// list; [`HavingAliasTransformer`] appends such aggregates to the SELECT list
/// under `HIDDEN_AGG_PREFIX` followed by a counter (e.g. `__hidden_agg_1`).
pub const HIDDEN_AGG_PREFIX: &str = "__hidden_agg_";

/// AST pass that aliases SELECT-list aggregates and rewrites the HAVING clause
/// to reference those aliases instead of the aggregate calls themselves.
#[derive(Debug)]
pub struct HavingAliasTransformer {
    /// Counter behind generated `__agg_N` aliases for unaliased SELECT aggregates.
    alias_counter: usize,
    /// Counter behind generated hidden aliases (`HIDDEN_AGG_PREFIX` + N).
    hidden_counter: usize,
}
48
49impl HavingAliasTransformer {
50 pub fn new() -> Self {
51 Self {
52 alias_counter: 0,
53 hidden_counter: 0,
54 }
55 }
56
57 fn is_aggregate_function(expr: &SqlExpression) -> bool {
59 matches!(
60 expr,
61 SqlExpression::FunctionCall { name, .. }
62 if matches!(
63 name.to_uppercase().as_str(),
64 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
65 )
66 )
67 }
68
69 fn generate_alias(&mut self) -> String {
71 self.alias_counter += 1;
72 format!("__agg_{}", self.alias_counter)
73 }
74
75 fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
77 match expr {
78 SqlExpression::FunctionCall { name, args, .. } => {
79 let args_str = args
80 .iter()
81 .map(|arg| match arg {
82 SqlExpression::Column(col_ref) => {
83 format!("{}", col_ref.name)
84 }
85 SqlExpression::StringLiteral(s) => format!("'{}'", s),
86 SqlExpression::NumberLiteral(n) => n.clone(),
87 _ => format!("{:?}", arg), })
89 .collect::<Vec<_>>()
90 .join(",");
91 format!("{}({})", name.to_uppercase(), args_str)
92 }
93 _ => format!("{:?}", expr),
94 }
95 }
96
97 fn ensure_aggregate_aliases(
99 &mut self,
100 select_items: &mut Vec<SelectItem>,
101 ) -> HashMap<String, String> {
102 let mut aggregate_map = HashMap::new();
103
104 for item in select_items.iter_mut() {
105 if let SelectItem::Expression { expr, alias, .. } = item {
106 if Self::is_aggregate_function(expr) {
107 if alias.is_empty() {
109 *alias = self.generate_alias();
110 debug!(
111 "Generated alias '{}' for aggregate: {}",
112 alias,
113 Self::normalize_aggregate_expr(expr)
114 );
115 }
116
117 let normalized = Self::normalize_aggregate_expr(expr);
119 aggregate_map.insert(normalized, alias.clone());
120 }
121 }
122 }
123
124 aggregate_map
125 }
126
127 fn generate_hidden_alias(&mut self) -> String {
129 self.hidden_counter += 1;
130 format!("{}{}", HIDDEN_AGG_PREFIX, self.hidden_counter)
131 }
132
133 fn collect_aggregates_in_having(expr: &SqlExpression, found: &mut Vec<SqlExpression>) {
135 match expr {
136 SqlExpression::FunctionCall { args, .. } if Self::is_aggregate_function(expr) => {
137 found.push(expr.clone());
138 let _ = args;
140 }
141 SqlExpression::BinaryOp { left, right, .. } => {
142 Self::collect_aggregates_in_having(left, found);
143 Self::collect_aggregates_in_having(right, found);
144 }
145 SqlExpression::Not { expr } => {
146 Self::collect_aggregates_in_having(expr, found);
147 }
148 SqlExpression::FunctionCall { args, .. } => {
149 for arg in args {
151 Self::collect_aggregates_in_having(arg, found);
152 }
153 }
154 _ => {}
155 }
156 }
157
158 fn promote_having_aggregates(
161 &mut self,
162 having_expr: &SqlExpression,
163 select_items: &mut Vec<SelectItem>,
164 aggregate_map: &mut HashMap<String, String>,
165 ) {
166 let mut having_aggs = Vec::new();
167 Self::collect_aggregates_in_having(having_expr, &mut having_aggs);
168
169 for agg in having_aggs {
170 let normalized = Self::normalize_aggregate_expr(&agg);
171 if aggregate_map.contains_key(&normalized) {
172 continue; }
174
175 let hidden_alias = self.generate_hidden_alias();
176 debug!(
177 "Promoting HAVING aggregate {} as hidden SELECT item '{}'",
178 normalized, hidden_alias
179 );
180
181 select_items.push(SelectItem::Expression {
182 expr: agg,
183 alias: hidden_alias.clone(),
184 leading_comments: Vec::new(),
185 trailing_comment: None,
186 });
187
188 aggregate_map.insert(normalized, hidden_alias);
189 }
190 }
191
192 fn rewrite_having_expression(
194 expr: &SqlExpression,
195 aggregate_map: &HashMap<String, String>,
196 ) -> SqlExpression {
197 match expr {
198 SqlExpression::FunctionCall { .. } if Self::is_aggregate_function(expr) => {
199 let normalized = Self::normalize_aggregate_expr(expr);
200 if let Some(alias) = aggregate_map.get(&normalized) {
201 debug!(
202 "Rewriting aggregate {} to column reference {}",
203 normalized, alias
204 );
205 SqlExpression::Column(ColumnRef {
206 name: alias.clone(),
207 quote_style: QuoteStyle::None,
208 table_prefix: None,
209 })
210 } else {
211 expr.clone()
213 }
214 }
215 SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
216 left: Box::new(Self::rewrite_having_expression(left, aggregate_map)),
217 op: op.clone(),
218 right: Box::new(Self::rewrite_having_expression(right, aggregate_map)),
219 },
220 SqlExpression::Not { expr } => SqlExpression::Not {
221 expr: Box::new(Self::rewrite_having_expression(expr, aggregate_map)),
222 },
223 SqlExpression::FunctionCall {
224 name,
225 args,
226 distinct,
227 } => {
228 SqlExpression::FunctionCall {
230 name: name.clone(),
231 args: args
232 .iter()
233 .map(|a| Self::rewrite_having_expression(a, aggregate_map))
234 .collect(),
235 distinct: *distinct,
236 }
237 }
238 _ => expr.clone(),
240 }
241 }
242}
243
244impl Default for HavingAliasTransformer {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250impl ASTTransformer for HavingAliasTransformer {
251 fn name(&self) -> &str {
252 "HavingAliasTransformer"
253 }
254
255 fn description(&self) -> &str {
256 "Adds aliases to aggregate functions and rewrites HAVING clauses to use them"
257 }
258
259 fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
260 if stmt.having.is_none() {
262 return Ok(stmt);
263 }
264
265 let mut aggregate_map = self.ensure_aggregate_aliases(&mut stmt.select_items);
267
268 if let Some(ref having_expr) = stmt.having {
271 self.promote_having_aggregates(having_expr, &mut stmt.select_items, &mut aggregate_map);
272 }
273
274 if aggregate_map.is_empty() {
275 return Ok(stmt);
277 }
278
279 if let Some(having_expr) = stmt.having.take() {
281 let rewritten = Self::rewrite_having_expression(&having_expr, &aggregate_map);
282
283 if format!("{:?}", having_expr) != format!("{:?}", rewritten) {
285 debug!(
286 "Rewrote HAVING clause with {} aggregate alias(es)",
287 aggregate_map.len()
288 );
289 stmt.having = Some(rewritten);
290 } else {
291 stmt.having = Some(having_expr);
292 }
293 }
294
295 Ok(stmt)
296 }
297
298 fn begin(&mut self) -> Result<()> {
299 self.alias_counter = 0;
301 self.hidden_counter = 0;
302 Ok(())
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unqualified, unquoted column reference expression.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call `name(args...)`.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_is_aggregate_function() {
        let count_expr = SqlExpression::FunctionCall {
            name: "COUNT".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "*".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };
        assert!(HavingAliasTransformer::is_aggregate_function(&count_expr));

        let sum_expr = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "amount".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };
        assert!(HavingAliasTransformer::is_aggregate_function(&sum_expr));

        let non_agg = SqlExpression::FunctionCall {
            name: "UPPER".to_string(),
            args: vec![],
            distinct: false,
        };
        assert!(!HavingAliasTransformer::is_aggregate_function(&non_agg));
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let count_star = SqlExpression::FunctionCall {
            name: "count".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "*".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&count_star),
            "COUNT(*)"
        );

        let sum_amount = SqlExpression::FunctionCall {
            name: "SUM".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "amount".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };
        assert_eq!(
            HavingAliasTransformer::normalize_aggregate_expr(&sum_amount),
            "SUM(amount)"
        );
    }

    #[test]
    fn test_generate_alias() {
        let mut transformer = HavingAliasTransformer::new();
        assert_eq!(transformer.generate_alias(), "__agg_1");
        assert_eq!(transformer.generate_alias(), "__agg_2");
        assert_eq!(transformer.generate_alias(), "__agg_3");
    }

    #[test]
    fn test_begin_resets_counters() {
        let mut transformer = HavingAliasTransformer::new();
        let first_hidden = format!("{}1", HIDDEN_AGG_PREFIX);

        assert_eq!(transformer.generate_alias(), "__agg_1");
        assert_eq!(transformer.generate_hidden_alias(), first_hidden);

        transformer.begin().unwrap();

        assert_eq!(transformer.generate_alias(), "__agg_1");
        assert_eq!(transformer.generate_hidden_alias(), first_hidden);
    }

    #[test]
    fn test_transform_with_no_having() {
        let mut transformer = HavingAliasTransformer::new();
        let stmt = SelectStatement {
            having: None,
            ..Default::default()
        };

        let result = transformer.transform(stmt);
        assert!(result.is_ok());
    }

    #[test]
    fn test_transform_adds_alias_and_rewrites_having() {
        let mut transformer = HavingAliasTransformer::new();

        let count_expr = SqlExpression::FunctionCall {
            name: "COUNT".to_string(),
            args: vec![SqlExpression::Column(ColumnRef {
                name: "*".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })],
            distinct: false,
        };

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: count_expr.clone(),
                alias: String::new(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(count_expr.clone()),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        if let SelectItem::Expression { alias, .. } = &result.select_items[0] {
            assert_eq!(alias, "__agg_1");
        } else {
            panic!("Expected Expression select item");
        }

        if let Some(SqlExpression::BinaryOp { left, .. }) = &result.having {
            match left.as_ref() {
                SqlExpression::Column(col_ref) => {
                    assert_eq!(col_ref.name, "__agg_1");
                }
                _ => panic!("Expected column reference in HAVING, got: {:?}", left),
            }
        } else {
            panic!("Expected BinaryOp in HAVING");
        }
    }

    #[test]
    fn test_existing_alias_is_reused_without_promotion() {
        let mut transformer = HavingAliasTransformer::new();

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: call("COUNT", vec![col("*")]),
                alias: "cnt".to_string(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            // Lower-case name in HAVING must still match the SELECT aggregate.
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(call("count", vec![col("*")])),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        assert_eq!(result.select_items.len(), 1);
        if let Some(SqlExpression::BinaryOp { left, .. }) = &result.having {
            match left.as_ref() {
                SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, "cnt"),
                other => panic!("Expected column reference in HAVING, got: {:?}", other),
            }
        } else {
            panic!("Expected BinaryOp in HAVING");
        }
    }

    #[test]
    fn test_having_only_aggregate_is_promoted_to_hidden_item() {
        let mut transformer = HavingAliasTransformer::new();

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: call("COUNT", vec![col("*")]),
                alias: "cnt".to_string(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(call("SUM", vec![col("amount")])),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("100".to_string())),
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();
        let expected_hidden = format!("{}1", HIDDEN_AGG_PREFIX);

        assert_eq!(result.select_items.len(), 2);
        if let SelectItem::Expression { alias, .. } = &result.select_items[1] {
            assert_eq!(alias, &expected_hidden);
        } else {
            panic!("Expected promoted Expression select item");
        }

        if let Some(SqlExpression::BinaryOp { left, .. }) = &result.having {
            match left.as_ref() {
                SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, expected_hidden),
                other => panic!("Expected column reference in HAVING, got: {:?}", other),
            }
        } else {
            panic!("Expected BinaryOp in HAVING");
        }
    }

    #[test]
    fn test_aggregate_nested_in_scalar_function_is_rewritten() {
        let mut transformer = HavingAliasTransformer::new();

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: call("AVG", vec![col("price")]),
                alias: "avg_price".to_string(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(call("ROUND", vec![call("avg", vec![col("price")])])),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("10".to_string())),
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        assert_eq!(result.select_items.len(), 1);
        if let Some(SqlExpression::BinaryOp { left, .. }) = &result.having {
            match left.as_ref() {
                SqlExpression::FunctionCall { name, args, .. } => {
                    assert_eq!(name, "ROUND");
                    match &args[0] {
                        SqlExpression::Column(col_ref) => {
                            assert_eq!(col_ref.name, "avg_price")
                        }
                        other => panic!("Expected column reference, got: {:?}", other),
                    }
                }
                other => panic!("Expected ROUND call in HAVING, got: {:?}", other),
            }
        } else {
            panic!("Expected BinaryOp in HAVING");
        }
    }
}
443}