Skip to main content

sql_cli/query_plan/
where_alias_expander.rs

//! WHERE clause alias expansion transformer
//!
//! This transformer lets users reference SELECT clause aliases in WHERE clauses
//! by automatically expanding those aliases to their full expressions.
//!
//! # Problem
//!
//! Users often want to reference complex SELECT expressions by their aliases in WHERE:
//! ```sql
//! SELECT a, a * 2 as double_a FROM t WHERE double_a > 10
//! ```
//!
//! This fails because WHERE is evaluated before SELECT, so the aliases do not exist yet.
//!
//! # Solution
//!
//! The transformer rewrites the query to:
//! ```sql
//! SELECT a, a * 2 as double_a FROM t WHERE a * 2 > 10
//! ```
//!
//! # Algorithm
//!
//! 1. Extract all aliases from the SELECT clause together with their expressions.
//! 2. Scan the WHERE clause for column references.
//! 3. If an unqualified column reference matches an alias name, replace it with the aliased expression.
//! 4. Recurse into nested expressions (binary operators, NOT, function and method calls, IN lists, BETWEEN, CASE).
//!
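//! # Limitations
//!
//! - Only unqualified references are expanded; `table.alias` references are left alone
//! - Aliases take precedence over actual column names if they conflict
//! - Complex expressions are duplicated (no common subexpression elimination)
//! - Alias names are matched exactly (case-sensitively) against column names
//! - Method-call receivers (`alias.Contains(...)`) are only rewritten when the alias
//!   stands for a plain, unqualified column
//! - No check is made for aggregate or window expressions, so an alias such as
//!   `COUNT(*) AS n` is expanded into WHERE like any other expression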
34
35use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// Transformer that expands SELECT aliases in WHERE clauses.
///
/// Each WHERE condition that references an unqualified SELECT alias has that
/// reference replaced by the aliased expression.
#[derive(Debug)]
pub struct WhereAliasExpander {
    /// Number of WHERE conditions rewritten since the counter was last reset.
    /// A condition counts once no matter how many alias references it held;
    /// the counter starts at zero and is reset by `begin`.
    expansions: usize,
}
46
47impl WhereAliasExpander {
48    pub fn new() -> Self {
49        Self { expansions: 0 }
50    }
51
52    /// Extract aliases from SELECT clause
53    /// Returns a map of alias name -> expression
54    fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55        let mut aliases = HashMap::new();
56
57        for item in select_items {
58            if let SelectItem::Expression { expr, alias, .. } = item {
59                if !alias.is_empty() {
60                    aliases.insert(alias.clone(), expr.clone());
61                    debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62                }
63            }
64        }
65
66        aliases
67    }
68
69    /// Recursively expand aliases in an expression
70    /// Returns the expanded expression and whether any expansion occurred
71    fn expand_expression(
72        expr: &SqlExpression,
73        aliases: &HashMap<String, SqlExpression>,
74    ) -> (SqlExpression, bool) {
75        match expr {
76            // Check if this column reference is actually an alias
77            SqlExpression::Column(col_ref) => {
78                // Only expand if it's a simple column (no table prefix)
79                if col_ref.table_prefix.is_none() {
80                    if let Some(alias_expr) = aliases.get(&col_ref.name) {
81                        debug!(
82                            "Expanding alias '{}' in WHERE to: {:?}",
83                            col_ref.name, alias_expr
84                        );
85                        return (alias_expr.clone(), true);
86                    }
87                }
88                (expr.clone(), false)
89            }
90
91            // Recursively expand binary operations
92            SqlExpression::BinaryOp { left, op, right } => {
93                let (new_left, left_expanded) = Self::expand_expression(left, aliases);
94                let (new_right, right_expanded) = Self::expand_expression(right, aliases);
95                let expanded = left_expanded || right_expanded;
96
97                (
98                    SqlExpression::BinaryOp {
99                        left: Box::new(new_left),
100                        op: op.clone(),
101                        right: Box::new(new_right),
102                    },
103                    expanded,
104                )
105            }
106
107            // Expand in NOT expressions
108            SqlExpression::Not { expr: inner } => {
109                let (new_expr, expanded) = Self::expand_expression(inner, aliases);
110                (
111                    SqlExpression::Not {
112                        expr: Box::new(new_expr),
113                    },
114                    expanded,
115                )
116            }
117
118            // Expand in function arguments
119            SqlExpression::FunctionCall {
120                name,
121                args,
122                distinct,
123            } => {
124                let mut expanded = false;
125                let new_args: Vec<SqlExpression> = args
126                    .iter()
127                    .map(|arg| {
128                        let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
129                        expanded = expanded || arg_expanded;
130                        new_arg
131                    })
132                    .collect();
133
134                (
135                    SqlExpression::FunctionCall {
136                        name: name.clone(),
137                        args: new_args,
138                        distinct: *distinct,
139                    },
140                    expanded,
141                )
142            }
143
144            // Expand in IN list expressions
145            SqlExpression::InList {
146                expr: inner,
147                values,
148            } => {
149                let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
150                let mut expanded = expr_expanded;
151
152                let new_values: Vec<SqlExpression> = values
153                    .iter()
154                    .map(|val| {
155                        let (new_val, val_expanded) = Self::expand_expression(val, aliases);
156                        expanded = expanded || val_expanded;
157                        new_val
158                    })
159                    .collect();
160
161                (
162                    SqlExpression::InList {
163                        expr: Box::new(new_expr),
164                        values: new_values,
165                    },
166                    expanded,
167                )
168            }
169
170            // Expand in NOT IN list expressions
171            SqlExpression::NotInList {
172                expr: inner,
173                values,
174            } => {
175                let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
176                let mut expanded = expr_expanded;
177
178                let new_values: Vec<SqlExpression> = values
179                    .iter()
180                    .map(|val| {
181                        let (new_val, val_expanded) = Self::expand_expression(val, aliases);
182                        expanded = expanded || val_expanded;
183                        new_val
184                    })
185                    .collect();
186
187                (
188                    SqlExpression::NotInList {
189                        expr: Box::new(new_expr),
190                        values: new_values,
191                    },
192                    expanded,
193                )
194            }
195
196            // Expand in BETWEEN expressions
197            SqlExpression::Between { expr, lower, upper } => {
198                let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
199                let (new_lower, lower_expanded) = Self::expand_expression(lower, aliases);
200                let (new_upper, upper_expanded) = Self::expand_expression(upper, aliases);
201                let expanded = expr_expanded || lower_expanded || upper_expanded;
202
203                (
204                    SqlExpression::Between {
205                        expr: Box::new(new_expr),
206                        lower: Box::new(new_lower),
207                        upper: Box::new(new_upper),
208                    },
209                    expanded,
210                )
211            }
212
213            // Expand in CASE expressions
214            SqlExpression::CaseExpression {
215                when_branches,
216                else_branch,
217            } => {
218                let mut expanded = false;
219                let new_branches: Vec<_> = when_branches
220                    .iter()
221                    .map(|branch| {
222                        let (new_condition, cond_expanded) =
223                            Self::expand_expression(&branch.condition, aliases);
224                        let (new_result, result_expanded) =
225                            Self::expand_expression(&branch.result, aliases);
226                        expanded = expanded || cond_expanded || result_expanded;
227
228                        crate::sql::parser::ast::WhenBranch {
229                            condition: Box::new(new_condition),
230                            result: Box::new(new_result),
231                        }
232                    })
233                    .collect();
234
235                let new_else = else_branch.as_ref().map(|e| {
236                    let (new_e, else_expanded) = Self::expand_expression(e, aliases);
237                    expanded = expanded || else_expanded;
238                    Box::new(new_e)
239                });
240
241                (
242                    SqlExpression::CaseExpression {
243                        when_branches: new_branches,
244                        else_branch: new_else,
245                    },
246                    expanded,
247                )
248            }
249
250            // Expand in simple CASE expressions
251            SqlExpression::SimpleCaseExpression {
252                expr,
253                when_branches,
254                else_branch,
255            } => {
256                let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
257                let mut expanded = expr_expanded;
258
259                let new_branches: Vec<_> = when_branches
260                    .iter()
261                    .map(|branch| {
262                        let (new_value, value_expanded) =
263                            Self::expand_expression(&branch.value, aliases);
264                        let (new_result, result_expanded) =
265                            Self::expand_expression(&branch.result, aliases);
266                        expanded = expanded || value_expanded || result_expanded;
267
268                        crate::sql::parser::ast::SimpleWhenBranch {
269                            value: Box::new(new_value),
270                            result: Box::new(new_result),
271                        }
272                    })
273                    .collect();
274
275                let new_else = else_branch.as_ref().map(|e| {
276                    let (new_e, else_expanded) = Self::expand_expression(e, aliases);
277                    expanded = expanded || else_expanded;
278                    Box::new(new_e)
279                });
280
281                (
282                    SqlExpression::SimpleCaseExpression {
283                        expr: Box::new(new_expr),
284                        when_branches: new_branches,
285                        else_branch: new_else,
286                    },
287                    expanded,
288                )
289            }
290
291            // Expand in method calls, e.g. `alias.Contains('x')`.
292            // The receiver is a bare column-name string, so an alias can only be
293            // substituted if it resolves to a simple (un-prefixed) column.
294            SqlExpression::MethodCall {
295                object,
296                method,
297                args,
298            } => {
299                let mut expanded = false;
300                let new_args: Vec<SqlExpression> = args
301                    .iter()
302                    .map(|arg| {
303                        let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
304                        expanded = expanded || arg_expanded;
305                        new_arg
306                    })
307                    .collect();
308
309                let mut new_object = object.clone();
310                if let Some(SqlExpression::Column(col_ref)) = aliases.get(object) {
311                    if col_ref.table_prefix.is_none() {
312                        debug!(
313                            "Expanding alias '{}' in WHERE method call to column '{}'",
314                            object, col_ref.name
315                        );
316                        new_object = col_ref.name.clone();
317                        expanded = true;
318                    }
319                }
320
321                (
322                    SqlExpression::MethodCall {
323                        object: new_object,
324                        method: method.clone(),
325                        args: new_args,
326                    },
327                    expanded,
328                )
329            }
330
331            // Expand in chained method calls, e.g. `(alias).Trim().Contains('x')`.
332            // The base is itself an expression, so recurse into it normally.
333            SqlExpression::ChainedMethodCall { base, method, args } => {
334                let (new_base, base_expanded) = Self::expand_expression(base, aliases);
335                let mut expanded = base_expanded;
336                let new_args: Vec<SqlExpression> = args
337                    .iter()
338                    .map(|arg| {
339                        let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
340                        expanded = expanded || arg_expanded;
341                        new_arg
342                    })
343                    .collect();
344
345                (
346                    SqlExpression::ChainedMethodCall {
347                        base: Box::new(new_base),
348                        method: method.clone(),
349                        args: new_args,
350                    },
351                    expanded,
352                )
353            }
354
355            // For all other expressions, return as-is
356            _ => (expr.clone(), false),
357        }
358    }
359
360    /// Expand aliases in WHERE clause conditions
361    fn expand_where_clause(
362        &mut self,
363        where_clause: &mut crate::sql::parser::ast::WhereClause,
364        aliases: &HashMap<String, SqlExpression>,
365    ) -> bool {
366        let mut any_expanded = false;
367
368        for condition in &mut where_clause.conditions {
369            let (new_expr, expanded) = Self::expand_expression(&condition.expr, aliases);
370            if expanded {
371                condition.expr = new_expr;
372                any_expanded = true;
373                self.expansions += 1;
374            }
375        }
376
377        any_expanded
378    }
379}
380
381impl Default for WhereAliasExpander {
382    fn default() -> Self {
383        Self::new()
384    }
385}
386
387impl ASTTransformer for WhereAliasExpander {
388    fn name(&self) -> &str {
389        "WhereAliasExpander"
390    }
391
392    fn description(&self) -> &str {
393        "Expands SELECT aliases in WHERE clauses to their full expressions"
394    }
395
396    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
397        // Only process if there's a WHERE clause
398        if stmt.where_clause.is_none() {
399            return Ok(stmt);
400        }
401
402        // Step 1: Extract all aliases from SELECT clause
403        let aliases = Self::extract_aliases(&stmt.select_items);
404
405        if aliases.is_empty() {
406            // No aliases to expand
407            return Ok(stmt);
408        }
409
410        // Step 2: Expand aliases in WHERE clause
411        if let Some(ref mut where_clause) = stmt.where_clause {
412            let expanded = self.expand_where_clause(where_clause, &aliases);
413            if expanded {
414                debug!(
415                    "Expanded {} alias reference(s) in WHERE clause",
416                    self.expansions
417                );
418            }
419        }
420
421        Ok(stmt)
422    }
423
424    fn begin(&mut self) -> Result<()> {
425        // Reset expansion counter for each query
426        self.expansions = 0;
427        Ok(())
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, Condition, QuoteStyle, WhereClause};

    /// Unqualified, unquoted column reference.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Numeric literal.
    fn num(value: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(value.to_string())
    }

    /// The expression `a * 2`.
    fn double_a_expr() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(col("a")),
            op: "*".to_string(),
            right: Box::new(num("2")),
        }
    }

    /// Alias map containing `double_a -> a * 2`.
    fn double_a_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("double_a".to_string(), double_a_expr())])
    }

    /// `SELECT <expr> AS <alias>` item without comments (empty alias = no alias).
    fn select_item(expr: SqlExpression, alias: &str) -> SelectItem {
        SelectItem::Expression {
            expr,
            alias: alias.to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// WHERE clause holding a single condition.
    fn where_with(expr: SqlExpression) -> WhereClause {
        WhereClause {
            conditions: vec![Condition {
                expr,
                connector: None,
            }],
        }
    }

    /// The condition `double_a > 10`.
    fn double_a_gt_10() -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(col("double_a")),
            op: ">".to_string(),
            right: Box::new(num("10")),
        }
    }

    #[test]
    fn test_extract_aliases() {
        let double_a = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef {
                name: "a".to_string(),
                quote_style: QuoteStyle::None,
                table_prefix: None,
            })),
            op: "*".to_string(),
            right: Box::new(num("2")),
        };

        let select_items = vec![select_item(double_a, "double_a")];

        let aliases = WhereAliasExpander::extract_aliases(&select_items);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("double_a"));
    }

    #[test]
    fn test_extract_aliases_skips_unaliased_items() {
        let select_items = vec![
            select_item(col("a"), ""),
            select_item(double_a_expr(), "double_a"),
        ];

        let aliases = WhereAliasExpander::extract_aliases(&select_items);
        assert_eq!(aliases.len(), 1);
        assert!(matches!(
            aliases.get("double_a"),
            Some(SqlExpression::BinaryOp { .. })
        ));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let aliases = double_a_aliases();

        let expr = col("double_a");
        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_expand_in_binary_op() {
        let aliases = double_a_aliases();

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&double_a_gt_10(), &aliases);

        assert!(changed);
        if let SqlExpression::BinaryOp { left, op, right } = expanded {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::NumberLiteral(s) if s == "10"
            ));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    #[test]
    fn test_expand_leaves_non_alias_columns_untouched() {
        let expr = SqlExpression::BinaryOp {
            left: Box::new(col("b")),
            op: ">".to_string(),
            right: Box::new(num("1")),
        };

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(!changed);
        match &expanded {
            SqlExpression::BinaryOp { left, .. } => {
                assert!(matches!(left.as_ref(), SqlExpression::Column(c) if c.name == "b"));
            }
            other => panic!("Expected BinaryOp, got {other:?}"),
        }
    }

    #[test]
    fn test_expand_inside_not_and_in_list() {
        // NOT (double_a IN (4, b))
        let expr = SqlExpression::Not {
            expr: Box::new(SqlExpression::InList {
                expr: Box::new(col("double_a")),
                values: vec![num("4"), col("b")],
            }),
        };

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match &expanded {
            SqlExpression::Not { expr: inner } => match inner.as_ref() {
                SqlExpression::InList {
                    expr: target,
                    values,
                } => {
                    assert!(matches!(target.as_ref(), SqlExpression::BinaryOp { .. }));
                    assert_eq!(values.len(), 2);
                    assert!(matches!(&values[1], SqlExpression::Column(c) if c.name == "b"));
                }
                other => panic!("Expected InList inside NOT, got {other:?}"),
            },
            other => panic!("Expected Not, got {other:?}"),
        }
    }

    #[test]
    fn test_expand_in_between_bounds() {
        // a BETWEEN double_a AND 100
        let expr = SqlExpression::Between {
            expr: Box::new(col("a")),
            lower: Box::new(col("double_a")),
            upper: Box::new(num("100")),
        };

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&expr, &double_a_aliases());

        assert!(changed);
        match &expanded {
            SqlExpression::Between { expr, lower, upper } => {
                assert!(matches!(expr.as_ref(), SqlExpression::Column(c) if c.name == "a"));
                assert!(matches!(lower.as_ref(), SqlExpression::BinaryOp { .. }));
                assert!(matches!(
                    upper.as_ref(),
                    SqlExpression::NumberLiteral(s) if s == "100"
                ));
            }
            other => panic!("Expected Between, got {other:?}"),
        }
    }

    #[test]
    fn test_transform_with_no_where() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = SelectStatement {
            where_clause: None,
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();
        assert!(result.where_clause.is_none());
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = WhereAliasExpander::new();

        let stmt = SelectStatement {
            select_items: vec![select_item(double_a_expr(), "double_a")],
            where_clause: Some(where_with(double_a_gt_10())),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // Check that WHERE was rewritten
        if let Some(where_clause) = &result.where_clause {
            if let SqlExpression::BinaryOp { left, .. } = &where_clause.conditions[0].expr {
                // Left side should now be the expanded expression (a * 2), not the column "double_a"
                assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            } else {
                panic!("Expected BinaryOp in WHERE");
            }
        } else {
            panic!("Expected WHERE clause");
        }

        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_transform_without_aliases_is_noop() {
        let mut transformer = WhereAliasExpander::new();

        // The SELECT item has no alias, so `double_a` in WHERE must stay a column.
        let stmt = SelectStatement {
            select_items: vec![select_item(double_a_expr(), "")],
            where_clause: Some(where_with(double_a_gt_10())),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();
        let where_clause = result
            .where_clause
            .expect("WHERE clause should be preserved");
        match &where_clause.conditions[0].expr {
            SqlExpression::BinaryOp { left, .. } => {
                assert!(matches!(
                    left.as_ref(),
                    SqlExpression::Column(c) if c.name == "double_a"
                ));
            }
            other => panic!("Expected BinaryOp in WHERE, got {other:?}"),
        }
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_begin_resets_expansion_counter() {
        let mut transformer = WhereAliasExpander::new();
        transformer.expansions = 3;

        transformer.begin().unwrap();

        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_expand_alias_in_method_call_receiver() {
        // `SELECT "name.common" as name ... WHERE name.Contains('x')`
        // The alias `name` resolves to the column `name.common`, so the method
        // call's receiver should be rewritten to that column name.
        let aliases = HashMap::from([(
            "name".to_string(),
            SqlExpression::Column(ColumnRef {
                name: "name.common".to_string(),
                quote_style: QuoteStyle::DoubleQuotes,
                table_prefix: None,
            }),
        )]);

        let expr = SqlExpression::MethodCall {
            object: "name".to_string(),
            method: "Contains".to_string(),
            args: vec![SqlExpression::StringLiteral("united".to_string())],
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        match expanded {
            SqlExpression::MethodCall { object, method, .. } => {
                assert_eq!(object, "name.common");
                assert_eq!(method, "Contains");
            }
            other => panic!("Expected MethodCall, got {other:?}"),
        }
    }

    #[test]
    fn test_does_not_expand_method_call_for_nonalias() {
        // A method call whose receiver is a real column (not an alias) is untouched.
        let aliases = HashMap::from([(
            "name".to_string(),
            SqlExpression::Column(ColumnRef::unquoted("name.common".to_string())),
        )]);

        let expr = SqlExpression::MethodCall {
            object: "capital".to_string(),
            method: "Contains".to_string(),
            args: vec![SqlExpression::StringLiteral("x".to_string())],
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(
            expanded,
            SqlExpression::MethodCall { object, .. } if object == "capital"
        ));
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        let aliases = double_a_aliases();

        // Column with table prefix should NOT be expanded
        let expr = SqlExpression::Column(ColumnRef {
            name: "double_a".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }
}