1use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// AST transformer that rewrites references to SELECT-list aliases inside the
/// WHERE clause into the aliased expressions themselves, so that e.g.
/// `SELECT a * 2 AS double_a ... WHERE double_a > 10` has its condition
/// rewritten to `a * 2 > 10`.
#[derive(Debug, Clone)]
pub struct WhereAliasExpander {
    /// Number of WHERE conditions rewritten since construction or the last
    /// `begin()` call. This counts conditions, not individual alias
    /// references: a condition that mentions several aliases adds one.
    expansions: usize,
}
46
47impl WhereAliasExpander {
48 pub fn new() -> Self {
49 Self { expansions: 0 }
50 }
51
52 fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55 let mut aliases = HashMap::new();
56
57 for item in select_items {
58 if let SelectItem::Expression { expr, alias, .. } = item {
59 if !alias.is_empty() {
60 aliases.insert(alias.clone(), expr.clone());
61 debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62 }
63 }
64 }
65
66 aliases
67 }
68
69 fn expand_expression(
72 expr: &SqlExpression,
73 aliases: &HashMap<String, SqlExpression>,
74 ) -> (SqlExpression, bool) {
75 match expr {
76 SqlExpression::Column(col_ref) => {
78 if col_ref.table_prefix.is_none() {
80 if let Some(alias_expr) = aliases.get(&col_ref.name) {
81 debug!(
82 "Expanding alias '{}' in WHERE to: {:?}",
83 col_ref.name, alias_expr
84 );
85 return (alias_expr.clone(), true);
86 }
87 }
88 (expr.clone(), false)
89 }
90
91 SqlExpression::BinaryOp { left, op, right } => {
93 let (new_left, left_expanded) = Self::expand_expression(left, aliases);
94 let (new_right, right_expanded) = Self::expand_expression(right, aliases);
95 let expanded = left_expanded || right_expanded;
96
97 (
98 SqlExpression::BinaryOp {
99 left: Box::new(new_left),
100 op: op.clone(),
101 right: Box::new(new_right),
102 },
103 expanded,
104 )
105 }
106
107 SqlExpression::Not { expr: inner } => {
109 let (new_expr, expanded) = Self::expand_expression(inner, aliases);
110 (
111 SqlExpression::Not {
112 expr: Box::new(new_expr),
113 },
114 expanded,
115 )
116 }
117
118 SqlExpression::FunctionCall {
120 name,
121 args,
122 distinct,
123 } => {
124 let mut expanded = false;
125 let new_args: Vec<SqlExpression> = args
126 .iter()
127 .map(|arg| {
128 let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
129 expanded = expanded || arg_expanded;
130 new_arg
131 })
132 .collect();
133
134 (
135 SqlExpression::FunctionCall {
136 name: name.clone(),
137 args: new_args,
138 distinct: *distinct,
139 },
140 expanded,
141 )
142 }
143
144 SqlExpression::InList {
146 expr: inner,
147 values,
148 } => {
149 let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
150 let mut expanded = expr_expanded;
151
152 let new_values: Vec<SqlExpression> = values
153 .iter()
154 .map(|val| {
155 let (new_val, val_expanded) = Self::expand_expression(val, aliases);
156 expanded = expanded || val_expanded;
157 new_val
158 })
159 .collect();
160
161 (
162 SqlExpression::InList {
163 expr: Box::new(new_expr),
164 values: new_values,
165 },
166 expanded,
167 )
168 }
169
170 SqlExpression::NotInList {
172 expr: inner,
173 values,
174 } => {
175 let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
176 let mut expanded = expr_expanded;
177
178 let new_values: Vec<SqlExpression> = values
179 .iter()
180 .map(|val| {
181 let (new_val, val_expanded) = Self::expand_expression(val, aliases);
182 expanded = expanded || val_expanded;
183 new_val
184 })
185 .collect();
186
187 (
188 SqlExpression::NotInList {
189 expr: Box::new(new_expr),
190 values: new_values,
191 },
192 expanded,
193 )
194 }
195
196 SqlExpression::Between { expr, lower, upper } => {
198 let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
199 let (new_lower, lower_expanded) = Self::expand_expression(lower, aliases);
200 let (new_upper, upper_expanded) = Self::expand_expression(upper, aliases);
201 let expanded = expr_expanded || lower_expanded || upper_expanded;
202
203 (
204 SqlExpression::Between {
205 expr: Box::new(new_expr),
206 lower: Box::new(new_lower),
207 upper: Box::new(new_upper),
208 },
209 expanded,
210 )
211 }
212
213 SqlExpression::CaseExpression {
215 when_branches,
216 else_branch,
217 } => {
218 let mut expanded = false;
219 let new_branches: Vec<_> = when_branches
220 .iter()
221 .map(|branch| {
222 let (new_condition, cond_expanded) =
223 Self::expand_expression(&branch.condition, aliases);
224 let (new_result, result_expanded) =
225 Self::expand_expression(&branch.result, aliases);
226 expanded = expanded || cond_expanded || result_expanded;
227
228 crate::sql::parser::ast::WhenBranch {
229 condition: Box::new(new_condition),
230 result: Box::new(new_result),
231 }
232 })
233 .collect();
234
235 let new_else = else_branch.as_ref().map(|e| {
236 let (new_e, else_expanded) = Self::expand_expression(e, aliases);
237 expanded = expanded || else_expanded;
238 Box::new(new_e)
239 });
240
241 (
242 SqlExpression::CaseExpression {
243 when_branches: new_branches,
244 else_branch: new_else,
245 },
246 expanded,
247 )
248 }
249
250 SqlExpression::SimpleCaseExpression {
252 expr,
253 when_branches,
254 else_branch,
255 } => {
256 let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
257 let mut expanded = expr_expanded;
258
259 let new_branches: Vec<_> = when_branches
260 .iter()
261 .map(|branch| {
262 let (new_value, value_expanded) =
263 Self::expand_expression(&branch.value, aliases);
264 let (new_result, result_expanded) =
265 Self::expand_expression(&branch.result, aliases);
266 expanded = expanded || value_expanded || result_expanded;
267
268 crate::sql::parser::ast::SimpleWhenBranch {
269 value: Box::new(new_value),
270 result: Box::new(new_result),
271 }
272 })
273 .collect();
274
275 let new_else = else_branch.as_ref().map(|e| {
276 let (new_e, else_expanded) = Self::expand_expression(e, aliases);
277 expanded = expanded || else_expanded;
278 Box::new(new_e)
279 });
280
281 (
282 SqlExpression::SimpleCaseExpression {
283 expr: Box::new(new_expr),
284 when_branches: new_branches,
285 else_branch: new_else,
286 },
287 expanded,
288 )
289 }
290
291 SqlExpression::MethodCall {
295 object,
296 method,
297 args,
298 } => {
299 let mut expanded = false;
300 let new_args: Vec<SqlExpression> = args
301 .iter()
302 .map(|arg| {
303 let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
304 expanded = expanded || arg_expanded;
305 new_arg
306 })
307 .collect();
308
309 let mut new_object = object.clone();
310 if let Some(SqlExpression::Column(col_ref)) = aliases.get(object) {
311 if col_ref.table_prefix.is_none() {
312 debug!(
313 "Expanding alias '{}' in WHERE method call to column '{}'",
314 object, col_ref.name
315 );
316 new_object = col_ref.name.clone();
317 expanded = true;
318 }
319 }
320
321 (
322 SqlExpression::MethodCall {
323 object: new_object,
324 method: method.clone(),
325 args: new_args,
326 },
327 expanded,
328 )
329 }
330
331 SqlExpression::ChainedMethodCall { base, method, args } => {
334 let (new_base, base_expanded) = Self::expand_expression(base, aliases);
335 let mut expanded = base_expanded;
336 let new_args: Vec<SqlExpression> = args
337 .iter()
338 .map(|arg| {
339 let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
340 expanded = expanded || arg_expanded;
341 new_arg
342 })
343 .collect();
344
345 (
346 SqlExpression::ChainedMethodCall {
347 base: Box::new(new_base),
348 method: method.clone(),
349 args: new_args,
350 },
351 expanded,
352 )
353 }
354
355 _ => (expr.clone(), false),
357 }
358 }
359
360 fn expand_where_clause(
362 &mut self,
363 where_clause: &mut crate::sql::parser::ast::WhereClause,
364 aliases: &HashMap<String, SqlExpression>,
365 ) -> bool {
366 let mut any_expanded = false;
367
368 for condition in &mut where_clause.conditions {
369 let (new_expr, expanded) = Self::expand_expression(&condition.expr, aliases);
370 if expanded {
371 condition.expr = new_expr;
372 any_expanded = true;
373 self.expansions += 1;
374 }
375 }
376
377 any_expanded
378 }
379}
380
381impl Default for WhereAliasExpander {
382 fn default() -> Self {
383 Self::new()
384 }
385}
386
387impl ASTTransformer for WhereAliasExpander {
388 fn name(&self) -> &str {
389 "WhereAliasExpander"
390 }
391
392 fn description(&self) -> &str {
393 "Expands SELECT aliases in WHERE clauses to their full expressions"
394 }
395
396 fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
397 if stmt.where_clause.is_none() {
399 return Ok(stmt);
400 }
401
402 let aliases = Self::extract_aliases(&stmt.select_items);
404
405 if aliases.is_empty() {
406 return Ok(stmt);
408 }
409
410 if let Some(ref mut where_clause) = stmt.where_clause {
412 let expanded = self.expand_where_clause(where_clause, &aliases);
413 if expanded {
414 debug!(
415 "Expanded {} alias reference(s) in WHERE clause",
416 self.expansions
417 );
418 }
419 }
420
421 Ok(stmt)
422 }
423
424 fn begin(&mut self) -> Result<()> {
425 self.expansions = 0;
427 Ok(())
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, Condition, QuoteStyle, WhereClause};

    /// Unqualified, unquoted column reference `name`.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Numeric literal with the given source text.
    fn num(text: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(text.to_string())
    }

    /// `left <op> right`.
    fn binop(left: SqlExpression, op: &str, right: SqlExpression) -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(left),
            op: op.to_string(),
            right: Box::new(right),
        }
    }

    /// `a * 2`, the expression behind the `double_a` alias used by most tests.
    fn double_a_expr() -> SqlExpression {
        binop(col("a"), "*", num("2"))
    }

    /// Alias map containing only `double_a -> a * 2`.
    fn double_a_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("double_a".to_string(), double_a_expr())])
    }

    /// SELECT-list entry `expr AS alias` with no comments attached.
    fn aliased(expr: SqlExpression, alias: &str) -> SelectItem {
        SelectItem::Expression {
            expr,
            alias: alias.to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// Statement with the given SELECT list and a single-condition WHERE clause.
    fn stmt_with_where(select_items: Vec<SelectItem>, condition: SqlExpression) -> SelectStatement {
        SelectStatement {
            select_items,
            where_clause: Some(WhereClause {
                conditions: vec![Condition {
                    expr: condition,
                    connector: None,
                }],
            }),
            ..Default::default()
        }
    }

    #[test]
    fn test_extract_aliases() {
        let double_a_expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef {
                name: "a".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })),
            op: "*".to_string(),
            right: Box::new(num("2")),
        };

        let select_items = vec![aliased(double_a_expr, "double_a")];

        let aliases = WhereAliasExpander::extract_aliases(&select_items);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("double_a"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&col("double_a"), &double_a_aliases());

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_expand_in_binary_op() {
        let expr = binop(col("double_a"), ">", num("10"));

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        if let SqlExpression::BinaryOp { left, op, right } = expanded {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::NumberLiteral(s) if s == "10"
            ));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    #[test]
    fn test_does_not_expand_unknown_column() {
        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&col("other"), &double_a_aliases());

        assert!(!changed);
        assert!(matches!(&expanded, SqlExpression::Column(c) if c.name == "other"));
    }

    #[test]
    fn test_expand_in_list_subject_and_values() {
        let expr = SqlExpression::InList {
            expr: Box::new(col("double_a")),
            values: vec![num("2"), col("double_a")],
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match expanded {
            SqlExpression::InList { expr, values } => {
                assert!(matches!(expr.as_ref(), SqlExpression::BinaryOp { .. }));
                assert_eq!(values.len(), 2);
                assert!(matches!(&values[0], SqlExpression::NumberLiteral(s) if s == "2"));
                assert!(matches!(&values[1], SqlExpression::BinaryOp { .. }));
            }
            other => panic!("Expected InList, got {other:?}"),
        }
    }

    #[test]
    fn test_expand_inside_not() {
        let expr = SqlExpression::Not {
            expr: Box::new(binop(col("double_a"), ">", num("10"))),
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match expanded {
            SqlExpression::Not { expr: inner } => match inner.as_ref() {
                SqlExpression::BinaryOp { left, .. } => {
                    assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
                }
                other => panic!("Expected BinaryOp under NOT, got {other:?}"),
            },
            other => panic!("Expected Not, got {other:?}"),
        }
    }

    #[test]
    fn test_expand_between_bounds() {
        let expr = SqlExpression::Between {
            expr: Box::new(num("5")),
            lower: Box::new(col("double_a")),
            upper: Box::new(num("100")),
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match expanded {
            SqlExpression::Between { expr, lower, upper } => {
                assert!(matches!(expr.as_ref(), SqlExpression::NumberLiteral(s) if s == "5"));
                assert!(matches!(lower.as_ref(), SqlExpression::BinaryOp { .. }));
                assert!(matches!(upper.as_ref(), SqlExpression::NumberLiteral(s) if s == "100"));
            }
            other => panic!("Expected Between, got {other:?}"),
        }
    }

    #[test]
    fn test_transform_with_no_where() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = SelectStatement {
            where_clause: None,
            ..Default::default()
        };

        let result = transformer.transform(stmt);
        assert!(result.is_ok());
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = WhereAliasExpander::new();

        let stmt = stmt_with_where(
            vec![aliased(double_a_expr(), "double_a")],
            binop(col("double_a"), ">", num("10")),
        );

        let result = transformer.transform(stmt).unwrap();

        if let Some(where_clause) = &result.where_clause {
            if let SqlExpression::BinaryOp { left, .. } = &where_clause.conditions[0].expr {
                assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            } else {
                panic!("Expected BinaryOp in WHERE");
            }
        } else {
            panic!("Expected WHERE clause");
        }

        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_transform_without_aliases_leaves_where_untouched() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = stmt_with_where(vec![], binop(col("double_a"), ">", num("10")));

        let result = transformer.transform(stmt).unwrap();

        let where_clause = result
            .where_clause
            .expect("WHERE clause should be preserved");
        match &where_clause.conditions[0].expr {
            SqlExpression::BinaryOp { left, .. } => assert!(matches!(
                left.as_ref(),
                SqlExpression::Column(c) if c.name == "double_a"
            )),
            other => panic!("Expected BinaryOp in WHERE, got {other:?}"),
        }
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_begin_resets_expansion_counter() {
        let mut transformer = WhereAliasExpander::new();
        transformer.expansions = 5;

        transformer.begin().unwrap();

        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_expand_alias_in_method_call_receiver() {
        let aliases = HashMap::from([(
            "name".to_string(),
            SqlExpression::Column(ColumnRef {
                name: "name.common".to_string(),
                quote_style: QuoteStyle::DoubleQuotes,
                table_prefix: None,
            }),
        )]);

        let expr = SqlExpression::MethodCall {
            object: "name".to_string(),
            method: "Contains".to_string(),
            args: vec![SqlExpression::StringLiteral("united".to_string())],
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        match expanded {
            SqlExpression::MethodCall { object, method, .. } => {
                assert_eq!(object, "name.common");
                assert_eq!(method, "Contains");
            }
            other => panic!("Expected MethodCall, got {other:?}"),
        }
    }

    #[test]
    fn test_does_not_expand_method_call_for_nonalias() {
        let aliases = HashMap::from([("name".to_string(), col("name.common"))]);

        let expr = SqlExpression::MethodCall {
            object: "capital".to_string(),
            method: "Contains".to_string(),
            args: vec![SqlExpression::StringLiteral("x".to_string())],
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(
            expanded,
            SqlExpression::MethodCall { object, .. } if object == "capital"
        ));
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        let expr = SqlExpression::Column(ColumnRef {
            name: "double_a".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }
}