1use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// AST transformer that rewrites references to SELECT-list aliases inside the WHERE
/// clause into the aliased expressions themselves.
///
/// For example, `SELECT a * 2 AS double_a ... WHERE double_a > 10` is rewritten so the
/// WHERE clause reads `a * 2 > 10`.
#[derive(Debug)]
pub struct WhereAliasExpander {
    /// Running total of alias expansions recorded by this transformer; reset to zero in
    /// `begin`.
    expansions: usize,
}
46
47impl WhereAliasExpander {
48 pub fn new() -> Self {
49 Self { expansions: 0 }
50 }
51
52 fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55 let mut aliases = HashMap::new();
56
57 for item in select_items {
58 if let SelectItem::Expression { expr, alias, .. } = item {
59 if !alias.is_empty() {
60 aliases.insert(alias.clone(), expr.clone());
61 debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62 }
63 }
64 }
65
66 aliases
67 }
68
69 fn expand_expression(
72 expr: &SqlExpression,
73 aliases: &HashMap<String, SqlExpression>,
74 ) -> (SqlExpression, bool) {
75 match expr {
76 SqlExpression::Column(col_ref) => {
78 if col_ref.table_prefix.is_none() {
80 if let Some(alias_expr) = aliases.get(&col_ref.name) {
81 debug!(
82 "Expanding alias '{}' in WHERE to: {:?}",
83 col_ref.name, alias_expr
84 );
85 return (alias_expr.clone(), true);
86 }
87 }
88 (expr.clone(), false)
89 }
90
91 SqlExpression::BinaryOp { left, op, right } => {
93 let (new_left, left_expanded) = Self::expand_expression(left, aliases);
94 let (new_right, right_expanded) = Self::expand_expression(right, aliases);
95 let expanded = left_expanded || right_expanded;
96
97 (
98 SqlExpression::BinaryOp {
99 left: Box::new(new_left),
100 op: op.clone(),
101 right: Box::new(new_right),
102 },
103 expanded,
104 )
105 }
106
107 SqlExpression::Not { expr: inner } => {
109 let (new_expr, expanded) = Self::expand_expression(inner, aliases);
110 (
111 SqlExpression::Not {
112 expr: Box::new(new_expr),
113 },
114 expanded,
115 )
116 }
117
118 SqlExpression::FunctionCall {
120 name,
121 args,
122 distinct,
123 } => {
124 let mut expanded = false;
125 let new_args: Vec<SqlExpression> = args
126 .iter()
127 .map(|arg| {
128 let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
129 expanded = expanded || arg_expanded;
130 new_arg
131 })
132 .collect();
133
134 (
135 SqlExpression::FunctionCall {
136 name: name.clone(),
137 args: new_args,
138 distinct: *distinct,
139 },
140 expanded,
141 )
142 }
143
144 SqlExpression::InList {
146 expr: inner,
147 values,
148 } => {
149 let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
150 let mut expanded = expr_expanded;
151
152 let new_values: Vec<SqlExpression> = values
153 .iter()
154 .map(|val| {
155 let (new_val, val_expanded) = Self::expand_expression(val, aliases);
156 expanded = expanded || val_expanded;
157 new_val
158 })
159 .collect();
160
161 (
162 SqlExpression::InList {
163 expr: Box::new(new_expr),
164 values: new_values,
165 },
166 expanded,
167 )
168 }
169
170 SqlExpression::NotInList {
172 expr: inner,
173 values,
174 } => {
175 let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
176 let mut expanded = expr_expanded;
177
178 let new_values: Vec<SqlExpression> = values
179 .iter()
180 .map(|val| {
181 let (new_val, val_expanded) = Self::expand_expression(val, aliases);
182 expanded = expanded || val_expanded;
183 new_val
184 })
185 .collect();
186
187 (
188 SqlExpression::NotInList {
189 expr: Box::new(new_expr),
190 values: new_values,
191 },
192 expanded,
193 )
194 }
195
196 SqlExpression::Between { expr, lower, upper } => {
198 let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
199 let (new_lower, lower_expanded) = Self::expand_expression(lower, aliases);
200 let (new_upper, upper_expanded) = Self::expand_expression(upper, aliases);
201 let expanded = expr_expanded || lower_expanded || upper_expanded;
202
203 (
204 SqlExpression::Between {
205 expr: Box::new(new_expr),
206 lower: Box::new(new_lower),
207 upper: Box::new(new_upper),
208 },
209 expanded,
210 )
211 }
212
213 SqlExpression::CaseExpression {
215 when_branches,
216 else_branch,
217 } => {
218 let mut expanded = false;
219 let new_branches: Vec<_> = when_branches
220 .iter()
221 .map(|branch| {
222 let (new_condition, cond_expanded) =
223 Self::expand_expression(&branch.condition, aliases);
224 let (new_result, result_expanded) =
225 Self::expand_expression(&branch.result, aliases);
226 expanded = expanded || cond_expanded || result_expanded;
227
228 crate::sql::parser::ast::WhenBranch {
229 condition: Box::new(new_condition),
230 result: Box::new(new_result),
231 }
232 })
233 .collect();
234
235 let new_else = else_branch.as_ref().map(|e| {
236 let (new_e, else_expanded) = Self::expand_expression(e, aliases);
237 expanded = expanded || else_expanded;
238 Box::new(new_e)
239 });
240
241 (
242 SqlExpression::CaseExpression {
243 when_branches: new_branches,
244 else_branch: new_else,
245 },
246 expanded,
247 )
248 }
249
250 SqlExpression::SimpleCaseExpression {
252 expr,
253 when_branches,
254 else_branch,
255 } => {
256 let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
257 let mut expanded = expr_expanded;
258
259 let new_branches: Vec<_> = when_branches
260 .iter()
261 .map(|branch| {
262 let (new_value, value_expanded) =
263 Self::expand_expression(&branch.value, aliases);
264 let (new_result, result_expanded) =
265 Self::expand_expression(&branch.result, aliases);
266 expanded = expanded || value_expanded || result_expanded;
267
268 crate::sql::parser::ast::SimpleWhenBranch {
269 value: Box::new(new_value),
270 result: Box::new(new_result),
271 }
272 })
273 .collect();
274
275 let new_else = else_branch.as_ref().map(|e| {
276 let (new_e, else_expanded) = Self::expand_expression(e, aliases);
277 expanded = expanded || else_expanded;
278 Box::new(new_e)
279 });
280
281 (
282 SqlExpression::SimpleCaseExpression {
283 expr: Box::new(new_expr),
284 when_branches: new_branches,
285 else_branch: new_else,
286 },
287 expanded,
288 )
289 }
290
291 _ => (expr.clone(), false),
293 }
294 }
295
296 fn expand_where_clause(
298 &mut self,
299 where_clause: &mut crate::sql::parser::ast::WhereClause,
300 aliases: &HashMap<String, SqlExpression>,
301 ) -> bool {
302 let mut any_expanded = false;
303
304 for condition in &mut where_clause.conditions {
305 let (new_expr, expanded) = Self::expand_expression(&condition.expr, aliases);
306 if expanded {
307 condition.expr = new_expr;
308 any_expanded = true;
309 self.expansions += 1;
310 }
311 }
312
313 any_expanded
314 }
315}
316
317impl Default for WhereAliasExpander {
318 fn default() -> Self {
319 Self::new()
320 }
321}
322
323impl ASTTransformer for WhereAliasExpander {
324 fn name(&self) -> &str {
325 "WhereAliasExpander"
326 }
327
328 fn description(&self) -> &str {
329 "Expands SELECT aliases in WHERE clauses to their full expressions"
330 }
331
332 fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
333 if stmt.where_clause.is_none() {
335 return Ok(stmt);
336 }
337
338 let aliases = Self::extract_aliases(&stmt.select_items);
340
341 if aliases.is_empty() {
342 return Ok(stmt);
344 }
345
346 if let Some(ref mut where_clause) = stmt.where_clause {
348 let expanded = self.expand_where_clause(where_clause, &aliases);
349 if expanded {
350 debug!(
351 "Expanded {} alias reference(s) in WHERE clause",
352 self.expansions
353 );
354 }
355 }
356
357 Ok(stmt)
358 }
359
360 fn begin(&mut self) -> Result<()> {
361 self.expansions = 0;
363 Ok(())
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, Condition, QuoteStyle, WhereClause};

    /// Builds the unqualified, unquoted column reference `name`.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Builds the numeric literal `text`.
    fn number(text: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(text.to_string())
    }

    /// Builds `a * 2`, the body of the `double_a` alias used throughout these tests.
    fn double_a_expr() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(col("a")),
            op: "*".to_string(),
            right: Box::new(number("2")),
        }
    }

    /// Alias map containing only `double_a -> a * 2`.
    fn double_a_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("double_a".to_string(), double_a_expr())])
    }

    /// Builds `SELECT a * 2 [AS <alias>] ... WHERE double_a > 10`.
    ///
    /// An empty `alias` produces an unaliased SELECT item.
    fn stmt_selecting_double_a_as(alias: &str) -> SelectStatement {
        SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: double_a_expr(),
                alias: alias.to_string(),
                leading_comments: vec![],
                trailing_comment: None,
            }],
            where_clause: Some(WhereClause {
                conditions: vec![Condition {
                    expr: SqlExpression::BinaryOp {
                        left: Box::new(col("double_a")),
                        op: ">".to_string(),
                        right: Box::new(number("10")),
                    },
                    connector: None,
                }],
            }),
            ..Default::default()
        }
    }

    #[test]
    fn test_extract_aliases() {
        let select_items = vec![
            SelectItem::Expression {
                expr: double_a_expr(),
                alias: "double_a".to_string(),
                leading_comments: vec![],
                trailing_comment: None,
            },
            // Unaliased items must not produce an entry.
            SelectItem::Expression {
                expr: col("b"),
                alias: String::new(),
                leading_comments: vec![],
                trailing_comment: None,
            },
        ];

        let aliases = WhereAliasExpander::extract_aliases(&select_items);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("double_a"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let expr = col("double_a");
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_unmatched_column_is_unchanged() {
        let expr = col("other");
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }

    #[test]
    fn test_expand_in_binary_op() {
        let expr = SqlExpression::BinaryOp {
            left: Box::new(col("double_a")),
            op: ">".to_string(),
            right: Box::new(number("10")),
        };

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        if let SqlExpression::BinaryOp { left, op, right } = expanded {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::NumberLiteral(s) if s == "10"
            ));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    #[test]
    fn test_expand_inside_not_and_in_list() {
        let expr = SqlExpression::Not {
            expr: Box::new(SqlExpression::InList {
                expr: Box::new(col("double_a")),
                values: vec![number("2"), number("4")],
            }),
        };

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match expanded {
            SqlExpression::Not { expr: inner, .. } => match inner.as_ref() {
                SqlExpression::InList {
                    expr: tested,
                    values,
                    ..
                } => {
                    assert!(matches!(tested.as_ref(), SqlExpression::BinaryOp { .. }));
                    assert_eq!(values.len(), 2);
                }
                _ => panic!("Expected InList under NOT"),
            },
            _ => panic!("Expected NOT"),
        }
    }

    #[test]
    fn test_expand_in_between() {
        let expr = SqlExpression::Between {
            expr: Box::new(col("double_a")),
            lower: Box::new(number("1")),
            upper: Box::new(number("100")),
        };

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match expanded {
            SqlExpression::Between {
                expr: tested,
                lower,
                upper,
                ..
            } => {
                assert!(matches!(tested.as_ref(), SqlExpression::BinaryOp { .. }));
                assert!(matches!(lower.as_ref(), SqlExpression::NumberLiteral(s) if s == "1"));
                assert!(matches!(upper.as_ref(), SqlExpression::NumberLiteral(s) if s == "100"));
            }
            _ => panic!("Expected BETWEEN"),
        }
    }

    #[test]
    fn test_transform_with_no_where() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = SelectStatement {
            where_clause: None,
            ..Default::default()
        };

        let result = transformer.transform(stmt);
        assert!(result.is_ok());
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = WhereAliasExpander::new();

        let result = transformer
            .transform(stmt_selecting_double_a_as("double_a"))
            .unwrap();

        let where_clause = result.where_clause.expect("Expected WHERE clause");
        if let SqlExpression::BinaryOp { left, .. } = &where_clause.conditions[0].expr {
            assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
        } else {
            panic!("Expected BinaryOp in WHERE");
        }

        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_transform_without_aliases_is_noop() {
        let mut transformer = WhereAliasExpander::new();

        let result = transformer.transform(stmt_selecting_double_a_as("")).unwrap();

        let where_clause = result.where_clause.expect("Expected WHERE clause");
        if let SqlExpression::BinaryOp { left, .. } = &where_clause.conditions[0].expr {
            assert!(matches!(left.as_ref(), SqlExpression::Column(_)));
        } else {
            panic!("Expected BinaryOp in WHERE");
        }

        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_begin_resets_counter() {
        let mut transformer = WhereAliasExpander::new();
        transformer
            .transform(stmt_selecting_double_a_as("double_a"))
            .unwrap();
        assert_eq!(transformer.expansions, 1);

        transformer.begin().unwrap();
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        // `t.double_a` names a real table column, never the SELECT alias.
        let expr = SqlExpression::Column(ColumnRef {
            name: "double_a".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }
}