1use crate::query_plan::pipeline::ASTTransformer;
32use crate::sql::parser::ast::{ColumnRef, QuoteStyle, SelectItem, SelectStatement, SqlExpression};
33use anyhow::Result;
34use std::collections::HashMap;
35use tracing::debug;
36
/// AST pass that lets `HAVING` refer to SELECT-list aggregates by alias.
///
/// Unaliased aggregate calls in the SELECT list receive synthetic aliases
/// (`__agg_1`, `__agg_2`, ...), and matching aggregate calls in the `HAVING`
/// clause are replaced by plain column references to those aliases.
#[derive(Debug)]
pub struct HavingAliasTransformer {
    /// Number of synthetic aliases handed out so far; the next alias is
    /// `__agg_{alias_counter + 1}`.
    alias_counter: usize,
}
42
43impl HavingAliasTransformer {
44 pub fn new() -> Self {
45 Self { alias_counter: 0 }
46 }
47
48 fn is_aggregate_function(expr: &SqlExpression) -> bool {
50 matches!(
51 expr,
52 SqlExpression::FunctionCall { name, .. }
53 if matches!(
54 name.to_uppercase().as_str(),
55 "COUNT" | "SUM" | "AVG" | "MIN" | "MAX" | "COUNT_DISTINCT"
56 )
57 )
58 }
59
60 fn generate_alias(&mut self) -> String {
62 self.alias_counter += 1;
63 format!("__agg_{}", self.alias_counter)
64 }
65
66 fn normalize_aggregate_expr(expr: &SqlExpression) -> String {
68 match expr {
69 SqlExpression::FunctionCall { name, args, .. } => {
70 let args_str = args
71 .iter()
72 .map(|arg| match arg {
73 SqlExpression::Column(col_ref) => {
74 format!("{}", col_ref.name)
75 }
76 SqlExpression::StringLiteral(s) => format!("'{}'", s),
77 SqlExpression::NumberLiteral(n) => n.clone(),
78 _ => format!("{:?}", arg), })
80 .collect::<Vec<_>>()
81 .join(",");
82 format!("{}({})", name.to_uppercase(), args_str)
83 }
84 _ => format!("{:?}", expr),
85 }
86 }
87
88 fn ensure_aggregate_aliases(
90 &mut self,
91 select_items: &mut Vec<SelectItem>,
92 ) -> HashMap<String, String> {
93 let mut aggregate_map = HashMap::new();
94
95 for item in select_items.iter_mut() {
96 if let SelectItem::Expression { expr, alias, .. } = item {
97 if Self::is_aggregate_function(expr) {
98 if alias.is_empty() {
100 *alias = self.generate_alias();
101 debug!(
102 "Generated alias '{}' for aggregate: {}",
103 alias,
104 Self::normalize_aggregate_expr(expr)
105 );
106 }
107
108 let normalized = Self::normalize_aggregate_expr(expr);
110 aggregate_map.insert(normalized, alias.clone());
111 }
112 }
113 }
114
115 aggregate_map
116 }
117
118 fn rewrite_having_expression(
120 expr: &SqlExpression,
121 aggregate_map: &HashMap<String, String>,
122 ) -> SqlExpression {
123 match expr {
124 SqlExpression::FunctionCall { .. } if Self::is_aggregate_function(expr) => {
125 let normalized = Self::normalize_aggregate_expr(expr);
126 if let Some(alias) = aggregate_map.get(&normalized) {
127 debug!(
128 "Rewriting aggregate {} to column reference {}",
129 normalized, alias
130 );
131 SqlExpression::Column(ColumnRef {
132 name: alias.clone(),
133 quote_style: QuoteStyle::None,
134 table_prefix: None,
135 })
136 } else {
137 expr.clone()
139 }
140 }
141 SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
142 left: Box::new(Self::rewrite_having_expression(left, aggregate_map)),
143 op: op.clone(),
144 right: Box::new(Self::rewrite_having_expression(right, aggregate_map)),
145 },
146 _ => expr.clone(),
148 }
149 }
150}
151
152impl Default for HavingAliasTransformer {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158impl ASTTransformer for HavingAliasTransformer {
159 fn name(&self) -> &str {
160 "HavingAliasTransformer"
161 }
162
163 fn description(&self) -> &str {
164 "Adds aliases to aggregate functions and rewrites HAVING clauses to use them"
165 }
166
167 fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
168 if stmt.having.is_none() {
170 return Ok(stmt);
171 }
172
173 let aggregate_map = self.ensure_aggregate_aliases(&mut stmt.select_items);
175
176 if aggregate_map.is_empty() {
177 return Ok(stmt);
179 }
180
181 if let Some(having_expr) = stmt.having.take() {
183 let rewritten = Self::rewrite_having_expression(&having_expr, &aggregate_map);
184
185 if format!("{:?}", having_expr) != format!("{:?}", rewritten) {
187 debug!(
188 "Rewrote HAVING clause with {} aggregate alias(es)",
189 aggregate_map.len()
190 );
191 stmt.having = Some(rewritten);
192 } else {
193 stmt.having = Some(having_expr);
194 }
195 }
196
197 Ok(stmt)
198 }
199
200 fn begin(&mut self) -> Result<()> {
201 self.alias_counter = 0;
203 Ok(())
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted, unqualified column reference.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef {
            name: name.to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: None,
        })
    }

    /// Builds a non-DISTINCT function call.
    fn call(name: &str, args: Vec<SqlExpression>) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
        }
    }

    #[test]
    fn test_is_aggregate_function() {
        let count_expr = call("COUNT", vec![column("*")]);
        let sum_expr = call("SUM", vec![column("amount")]);
        let non_agg = call("UPPER", vec![]);

        assert!(HavingAliasTransformer::is_aggregate_function(&count_expr));
        assert!(HavingAliasTransformer::is_aggregate_function(&sum_expr));
        assert!(!HavingAliasTransformer::is_aggregate_function(&non_agg));
    }

    #[test]
    fn test_normalize_aggregate_expr() {
        let cases = [
            (call("count", vec![column("*")]), "COUNT(*)"),
            (call("SUM", vec![column("amount")]), "SUM(amount)"),
        ];
        for (expr, expected) in cases.iter() {
            assert_eq!(
                HavingAliasTransformer::normalize_aggregate_expr(expr),
                *expected
            );
        }
    }

    #[test]
    fn test_generate_alias() {
        let mut transformer = HavingAliasTransformer::new();
        for expected in ["__agg_1", "__agg_2", "__agg_3"].iter() {
            assert_eq!(transformer.generate_alias(), *expected);
        }
    }

    #[test]
    fn test_transform_with_no_having() {
        let mut transformer = HavingAliasTransformer::new();
        let stmt = SelectStatement {
            having: None,
            ..Default::default()
        };

        assert!(transformer.transform(stmt).is_ok());
    }

    #[test]
    fn test_transform_adds_alias_and_rewrites_having() {
        let mut transformer = HavingAliasTransformer::new();
        let count_star = call("COUNT", vec![column("*")]);

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: count_star.clone(),
                alias: String::new(),
                leading_comments: Vec::new(),
                trailing_comment: None,
            }],
            having: Some(SqlExpression::BinaryOp {
                left: Box::new(count_star),
                op: ">".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("5".to_string())),
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        match &result.select_items[0] {
            SelectItem::Expression { alias, .. } => assert_eq!(alias, "__agg_1"),
            #[allow(unreachable_patterns)]
            _ => panic!("Expected Expression select item"),
        }

        match &result.having {
            Some(SqlExpression::BinaryOp { left, .. }) => match left.as_ref() {
                SqlExpression::Column(col_ref) => assert_eq!(col_ref.name, "__agg_1"),
                _ => panic!("Expected column reference in HAVING, got: {:?}", left),
            },
            _ => panic!("Expected BinaryOp in HAVING"),
        }
    }
}