sql_cli/query_plan/
where_alias_expander.rs

//! WHERE clause alias expansion transformer
//!
//! This transformer lets users reference SELECT-clause aliases in WHERE clauses
//! by automatically expanding each alias reference to the full aliased expression.
//!
//! # Problem
//!
//! Users often want to refer to a complex SELECT expression by its alias in WHERE:
//! ```sql
//! SELECT a, a * 2 as double_a FROM t WHERE double_a > 10
//! ```
//!
//! This fails because WHERE is evaluated before SELECT, so the aliases do not exist yet.
//!
//! # Solution
//!
//! The transformer rewrites the query to:
//! ```sql
//! SELECT a, a * 2 as double_a FROM t WHERE a * 2 > 10
//! ```
//!
//! # Algorithm
//!
//! 1. Extract every alias from the SELECT clause together with its expression
//! 2. Scan the WHERE clause for column references
//! 3. If an unqualified column reference matches an alias name exactly, replace it
//!    with the aliased expression
//! 4. Recurse into nested expressions (binary operators, NOT, function calls,
//!    IN / NOT IN lists, BETWEEN, CASE)
//!
//! # Limitations
//!
//! - Only unqualified references are expanded (a table-qualified reference such as
//!   `t.double_a` is left alone)
//! - Aliases take precedence over real column names if they conflict
//! - Alias references nested inside expression kinds other than those listed in
//!   step 4 are not expanded
//! - Complex expressions are duplicated (no common subexpression elimination)
34
35use crate::query_plan::pipeline::ASTTransformer;
36use crate::sql::parser::ast::{SelectItem, SelectStatement, SqlExpression};
37use anyhow::Result;
38use std::collections::HashMap;
39use tracing::debug;
40
/// Transformer that expands SELECT aliases in WHERE clauses.
///
/// The expansion counter is reset by `ASTTransformer::begin`.
#[derive(Debug)]
pub struct WhereAliasExpander {
    /// Number of alias expansions performed since the counter was last reset.
    expansions: usize,
}
46
47impl WhereAliasExpander {
48    pub fn new() -> Self {
49        Self { expansions: 0 }
50    }
51
52    /// Extract aliases from SELECT clause
53    /// Returns a map of alias name -> expression
54    fn extract_aliases(select_items: &[SelectItem]) -> HashMap<String, SqlExpression> {
55        let mut aliases = HashMap::new();
56
57        for item in select_items {
58            if let SelectItem::Expression { expr, alias, .. } = item {
59                if !alias.is_empty() {
60                    aliases.insert(alias.clone(), expr.clone());
61                    debug!("Found SELECT alias: {} -> {:?}", alias, expr);
62                }
63            }
64        }
65
66        aliases
67    }
68
69    /// Recursively expand aliases in an expression
70    /// Returns the expanded expression and whether any expansion occurred
71    fn expand_expression(
72        expr: &SqlExpression,
73        aliases: &HashMap<String, SqlExpression>,
74    ) -> (SqlExpression, bool) {
75        match expr {
76            // Check if this column reference is actually an alias
77            SqlExpression::Column(col_ref) => {
78                // Only expand if it's a simple column (no table prefix)
79                if col_ref.table_prefix.is_none() {
80                    if let Some(alias_expr) = aliases.get(&col_ref.name) {
81                        debug!(
82                            "Expanding alias '{}' in WHERE to: {:?}",
83                            col_ref.name, alias_expr
84                        );
85                        return (alias_expr.clone(), true);
86                    }
87                }
88                (expr.clone(), false)
89            }
90
91            // Recursively expand binary operations
92            SqlExpression::BinaryOp { left, op, right } => {
93                let (new_left, left_expanded) = Self::expand_expression(left, aliases);
94                let (new_right, right_expanded) = Self::expand_expression(right, aliases);
95                let expanded = left_expanded || right_expanded;
96
97                (
98                    SqlExpression::BinaryOp {
99                        left: Box::new(new_left),
100                        op: op.clone(),
101                        right: Box::new(new_right),
102                    },
103                    expanded,
104                )
105            }
106
107            // Expand in NOT expressions
108            SqlExpression::Not { expr: inner } => {
109                let (new_expr, expanded) = Self::expand_expression(inner, aliases);
110                (
111                    SqlExpression::Not {
112                        expr: Box::new(new_expr),
113                    },
114                    expanded,
115                )
116            }
117
118            // Expand in function arguments
119            SqlExpression::FunctionCall {
120                name,
121                args,
122                distinct,
123            } => {
124                let mut expanded = false;
125                let new_args: Vec<SqlExpression> = args
126                    .iter()
127                    .map(|arg| {
128                        let (new_arg, arg_expanded) = Self::expand_expression(arg, aliases);
129                        expanded = expanded || arg_expanded;
130                        new_arg
131                    })
132                    .collect();
133
134                (
135                    SqlExpression::FunctionCall {
136                        name: name.clone(),
137                        args: new_args,
138                        distinct: *distinct,
139                    },
140                    expanded,
141                )
142            }
143
144            // Expand in IN list expressions
145            SqlExpression::InList {
146                expr: inner,
147                values,
148            } => {
149                let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
150                let mut expanded = expr_expanded;
151
152                let new_values: Vec<SqlExpression> = values
153                    .iter()
154                    .map(|val| {
155                        let (new_val, val_expanded) = Self::expand_expression(val, aliases);
156                        expanded = expanded || val_expanded;
157                        new_val
158                    })
159                    .collect();
160
161                (
162                    SqlExpression::InList {
163                        expr: Box::new(new_expr),
164                        values: new_values,
165                    },
166                    expanded,
167                )
168            }
169
170            // Expand in NOT IN list expressions
171            SqlExpression::NotInList {
172                expr: inner,
173                values,
174            } => {
175                let (new_expr, expr_expanded) = Self::expand_expression(inner, aliases);
176                let mut expanded = expr_expanded;
177
178                let new_values: Vec<SqlExpression> = values
179                    .iter()
180                    .map(|val| {
181                        let (new_val, val_expanded) = Self::expand_expression(val, aliases);
182                        expanded = expanded || val_expanded;
183                        new_val
184                    })
185                    .collect();
186
187                (
188                    SqlExpression::NotInList {
189                        expr: Box::new(new_expr),
190                        values: new_values,
191                    },
192                    expanded,
193                )
194            }
195
196            // Expand in BETWEEN expressions
197            SqlExpression::Between { expr, lower, upper } => {
198                let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
199                let (new_lower, lower_expanded) = Self::expand_expression(lower, aliases);
200                let (new_upper, upper_expanded) = Self::expand_expression(upper, aliases);
201                let expanded = expr_expanded || lower_expanded || upper_expanded;
202
203                (
204                    SqlExpression::Between {
205                        expr: Box::new(new_expr),
206                        lower: Box::new(new_lower),
207                        upper: Box::new(new_upper),
208                    },
209                    expanded,
210                )
211            }
212
213            // Expand in CASE expressions
214            SqlExpression::CaseExpression {
215                when_branches,
216                else_branch,
217            } => {
218                let mut expanded = false;
219                let new_branches: Vec<_> = when_branches
220                    .iter()
221                    .map(|branch| {
222                        let (new_condition, cond_expanded) =
223                            Self::expand_expression(&branch.condition, aliases);
224                        let (new_result, result_expanded) =
225                            Self::expand_expression(&branch.result, aliases);
226                        expanded = expanded || cond_expanded || result_expanded;
227
228                        crate::sql::parser::ast::WhenBranch {
229                            condition: Box::new(new_condition),
230                            result: Box::new(new_result),
231                        }
232                    })
233                    .collect();
234
235                let new_else = else_branch.as_ref().map(|e| {
236                    let (new_e, else_expanded) = Self::expand_expression(e, aliases);
237                    expanded = expanded || else_expanded;
238                    Box::new(new_e)
239                });
240
241                (
242                    SqlExpression::CaseExpression {
243                        when_branches: new_branches,
244                        else_branch: new_else,
245                    },
246                    expanded,
247                )
248            }
249
250            // Expand in simple CASE expressions
251            SqlExpression::SimpleCaseExpression {
252                expr,
253                when_branches,
254                else_branch,
255            } => {
256                let (new_expr, expr_expanded) = Self::expand_expression(expr, aliases);
257                let mut expanded = expr_expanded;
258
259                let new_branches: Vec<_> = when_branches
260                    .iter()
261                    .map(|branch| {
262                        let (new_value, value_expanded) =
263                            Self::expand_expression(&branch.value, aliases);
264                        let (new_result, result_expanded) =
265                            Self::expand_expression(&branch.result, aliases);
266                        expanded = expanded || value_expanded || result_expanded;
267
268                        crate::sql::parser::ast::SimpleWhenBranch {
269                            value: Box::new(new_value),
270                            result: Box::new(new_result),
271                        }
272                    })
273                    .collect();
274
275                let new_else = else_branch.as_ref().map(|e| {
276                    let (new_e, else_expanded) = Self::expand_expression(e, aliases);
277                    expanded = expanded || else_expanded;
278                    Box::new(new_e)
279                });
280
281                (
282                    SqlExpression::SimpleCaseExpression {
283                        expr: Box::new(new_expr),
284                        when_branches: new_branches,
285                        else_branch: new_else,
286                    },
287                    expanded,
288                )
289            }
290
291            // For all other expressions, return as-is
292            _ => (expr.clone(), false),
293        }
294    }
295
296    /// Expand aliases in WHERE clause conditions
297    fn expand_where_clause(
298        &mut self,
299        where_clause: &mut crate::sql::parser::ast::WhereClause,
300        aliases: &HashMap<String, SqlExpression>,
301    ) -> bool {
302        let mut any_expanded = false;
303
304        for condition in &mut where_clause.conditions {
305            let (new_expr, expanded) = Self::expand_expression(&condition.expr, aliases);
306            if expanded {
307                condition.expr = new_expr;
308                any_expanded = true;
309                self.expansions += 1;
310            }
311        }
312
313        any_expanded
314    }
315}
316
317impl Default for WhereAliasExpander {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
323impl ASTTransformer for WhereAliasExpander {
324    fn name(&self) -> &str {
325        "WhereAliasExpander"
326    }
327
328    fn description(&self) -> &str {
329        "Expands SELECT aliases in WHERE clauses to their full expressions"
330    }
331
332    fn transform(&mut self, mut stmt: SelectStatement) -> Result<SelectStatement> {
333        // Only process if there's a WHERE clause
334        if stmt.where_clause.is_none() {
335            return Ok(stmt);
336        }
337
338        // Step 1: Extract all aliases from SELECT clause
339        let aliases = Self::extract_aliases(&stmt.select_items);
340
341        if aliases.is_empty() {
342            // No aliases to expand
343            return Ok(stmt);
344        }
345
346        // Step 2: Expand aliases in WHERE clause
347        if let Some(ref mut where_clause) = stmt.where_clause {
348            let expanded = self.expand_where_clause(where_clause, &aliases);
349            if expanded {
350                debug!(
351                    "Expanded {} alias reference(s) in WHERE clause",
352                    self.expansions
353                );
354            }
355        }
356
357        Ok(stmt)
358    }
359
360    fn begin(&mut self) -> Result<()> {
361        // Reset expansion counter for each query
362        self.expansions = 0;
363        Ok(())
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::ast::{ColumnRef, Condition, QuoteStyle, WhereClause};

    /// Unqualified, unquoted column reference `name`.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Numeric literal with the given source text.
    fn num(text: &str) -> SqlExpression {
        SqlExpression::NumberLiteral(text.to_string())
    }

    /// Binary operation `left op right`.
    fn binop(left: SqlExpression, op: &str, right: SqlExpression) -> SqlExpression {
        SqlExpression::BinaryOp {
            left: Box::new(left),
            op: op.to_string(),
            right: Box::new(right),
        }
    }

    /// The `a * 2` expression aliased as `double_a` throughout these tests.
    fn double_a_expr() -> SqlExpression {
        binop(col("a"), "*", num("2"))
    }

    /// Alias map containing only `double_a -> a * 2`.
    fn double_a_aliases() -> HashMap<String, SqlExpression> {
        HashMap::from([("double_a".to_string(), double_a_expr())])
    }

    #[test]
    fn test_extract_aliases() {
        let select_items = vec![SelectItem::Expression {
            expr: double_a_expr(),
            alias: "double_a".to_string(),
            leading_comments: vec![],
            trailing_comment: None,
        }];

        let aliases = WhereAliasExpander::extract_aliases(&select_items);
        assert_eq!(aliases.len(), 1);
        assert!(aliases.contains_key("double_a"));
    }

    #[test]
    fn test_expand_simple_column_reference() {
        let aliases = double_a_aliases();

        let (expanded, changed) =
            WhereAliasExpander::expand_expression(&col("double_a"), &aliases);

        assert!(changed);
        assert!(matches!(expanded, SqlExpression::BinaryOp { .. }));
    }

    #[test]
    fn test_unknown_column_is_left_alone() {
        let aliases = double_a_aliases();

        let (expanded, changed) = WhereAliasExpander::expand_expression(&col("b"), &aliases);

        assert!(!changed);
        assert!(matches!(&expanded, SqlExpression::Column(c) if c.name == "b"));
    }

    #[test]
    fn test_expand_in_binary_op() {
        let aliases = double_a_aliases();
        let expr = binop(col("double_a"), ">", num("10"));

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        if let SqlExpression::BinaryOp { left, op, right } = expanded {
            assert_eq!(op, ">");
            assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            assert!(matches!(
                right.as_ref(),
                SqlExpression::NumberLiteral(s) if s == "10"
            ));
        } else {
            panic!("Expected BinaryOp");
        }
    }

    #[test]
    fn test_expand_inside_not_in_list() {
        let aliases = double_a_aliases();
        // NOT (double_a IN (2, 4))
        let expr = SqlExpression::Not {
            expr: Box::new(SqlExpression::InList {
                expr: Box::new(col("double_a")),
                values: vec![num("2"), num("4")],
            }),
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        if let SqlExpression::Not { expr: inner } = expanded {
            if let SqlExpression::InList { expr, values } = inner.as_ref() {
                assert!(matches!(expr.as_ref(), SqlExpression::BinaryOp { .. }));
                assert_eq!(values.len(), 2);
            } else {
                panic!("Expected InList inside NOT");
            }
        } else {
            panic!("Expected NOT");
        }
    }

    #[test]
    fn test_expand_in_between_bound() {
        let aliases = double_a_aliases();
        // a BETWEEN 0 AND double_a
        let expr = SqlExpression::Between {
            expr: Box::new(col("a")),
            lower: Box::new(num("0")),
            upper: Box::new(col("double_a")),
        };

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(changed);
        if let SqlExpression::Between { expr, lower, upper } = &expanded {
            assert!(matches!(expr.as_ref(), SqlExpression::Column(c) if c.name == "a"));
            assert!(matches!(lower.as_ref(), SqlExpression::NumberLiteral(s) if s == "0"));
            assert!(matches!(upper.as_ref(), SqlExpression::BinaryOp { .. }));
        } else {
            panic!("Expected Between");
        }
    }

    #[test]
    fn test_transform_with_no_where() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = SelectStatement {
            where_clause: None,
            ..Default::default()
        };

        let result = transformer.transform(stmt);
        assert!(result.is_ok());
    }

    #[test]
    fn test_transform_without_aliases_leaves_where_unchanged() {
        let mut transformer = WhereAliasExpander::new();
        let stmt = SelectStatement {
            where_clause: Some(WhereClause {
                conditions: vec![Condition {
                    expr: binop(col("double_a"), ">", num("10")),
                    connector: None,
                }],
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        let where_clause = result.where_clause.expect("Expected WHERE clause");
        if let SqlExpression::BinaryOp { left, .. } = &where_clause.conditions[0].expr {
            assert!(matches!(left.as_ref(), SqlExpression::Column(c) if c.name == "double_a"));
        } else {
            panic!("Expected BinaryOp in WHERE");
        }
        assert_eq!(transformer.expansions, 0);
    }

    #[test]
    fn test_transform_expands_alias() {
        let mut transformer = WhereAliasExpander::new();

        let stmt = SelectStatement {
            select_items: vec![SelectItem::Expression {
                expr: double_a_expr(),
                alias: "double_a".to_string(),
                leading_comments: vec![],
                trailing_comment: None,
            }],
            where_clause: Some(WhereClause {
                conditions: vec![Condition {
                    expr: binop(col("double_a"), ">", num("10")),
                    connector: None,
                }],
            }),
            ..Default::default()
        };

        let result = transformer.transform(stmt).unwrap();

        // The left side should now be the expanded `a * 2`, not the column "double_a".
        if let Some(where_clause) = &result.where_clause {
            if let SqlExpression::BinaryOp { left, .. } = &where_clause.conditions[0].expr {
                assert!(matches!(left.as_ref(), SqlExpression::BinaryOp { .. }));
            } else {
                panic!("Expected BinaryOp in WHERE");
            }
        } else {
            panic!("Expected WHERE clause");
        }

        assert_eq!(transformer.expansions, 1);
    }

    #[test]
    fn test_does_not_expand_table_prefixed_columns() {
        let aliases = double_a_aliases();

        // A table-qualified column must NOT be treated as an alias reference.
        let expr = SqlExpression::Column(ColumnRef {
            name: "double_a".to_string(),
            quote_style: QuoteStyle::None,
            table_prefix: Some("t".to_string()),
        });

        let (expanded, changed) = WhereAliasExpander::expand_expression(&expr, &aliases);

        assert!(!changed);
        assert!(matches!(expanded, SqlExpression::Column(_)));
    }
}