sql_cli/query_plan/
in_operator_lifter.rs

1use crate::sql::parser::ast::{
2    CTEType, ColumnRef, Condition, SelectItem, SelectStatement, SqlExpression, CTE,
3};
4
/// Specialized lifter for `IN` / `NOT IN` conditions whose left-hand side is not a plain column.
///
/// Such left-hand sides (e.g. `LOWER(name) IN ('a', 'b')`) are moved into generated columns
/// named `__expr_N`, so the rewritten condition compares a plain column reference instead.
///
/// `Default` is derived so the type satisfies clippy's `new_without_default`; it is
/// equivalent to [`InOperatorLifter::new`].
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of column names generated so far; keeps every `__expr_N` name unique.
    column_counter: usize,
}
10
11impl InOperatorLifter {
12    pub fn new() -> Self {
13        InOperatorLifter { column_counter: 0 }
14    }
15
16    /// Generate a unique column name for lifted expressions
17    fn next_column_name(&mut self) -> String {
18        self.column_counter += 1;
19        format!("__expr_{}", self.column_counter)
20    }
21
22    /// Check if an IN expression needs lifting (has function call on left side)
23    pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24        match expr {
25            SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26                Self::is_complex_expression(expr)
27            }
28            _ => false,
29        }
30    }
31
32    /// Check if an expression is complex (not just a column reference)
33    fn is_complex_expression(expr: &SqlExpression) -> bool {
34        !matches!(expr, SqlExpression::Column(_))
35    }
36
37    /// Lift IN expressions with function calls
38    pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39        let mut lifted = Vec::new();
40
41        if let Some(ref mut where_clause) = stmt.where_clause {
42            let mut new_conditions = Vec::new();
43
44            for condition in &where_clause.conditions {
45                match &condition.expr {
46                    SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47                        let column_alias = self.next_column_name();
48
49                        // Record what we're lifting
50                        lifted.push(LiftedInExpression {
51                            original_expr: expr.as_ref().clone(),
52                            alias: column_alias.clone(),
53                            values: values.clone(),
54                            is_not_in: false,
55                        });
56
57                        // Create new simple condition
58                        new_conditions.push(Condition {
59                            expr: SqlExpression::InList {
60                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
61                                    column_alias,
62                                ))),
63                                values: values.clone(),
64                            },
65                            connector: condition.connector.clone(),
66                        });
67                    }
68                    SqlExpression::NotInList { expr, values }
69                        if Self::is_complex_expression(expr) =>
70                    {
71                        let column_alias = self.next_column_name();
72
73                        // Record what we're lifting
74                        lifted.push(LiftedInExpression {
75                            original_expr: expr.as_ref().clone(),
76                            alias: column_alias.clone(),
77                            values: values.clone(),
78                            is_not_in: true,
79                        });
80
81                        // Create new simple condition
82                        new_conditions.push(Condition {
83                            expr: SqlExpression::NotInList {
84                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
85                                    column_alias,
86                                ))),
87                                values: values.clone(),
88                            },
89                            connector: condition.connector.clone(),
90                        });
91                    }
92                    _ => {
93                        // Keep condition as-is
94                        new_conditions.push(condition.clone());
95                    }
96                }
97            }
98
99            // Update WHERE clause with simplified conditions
100            where_clause.conditions = new_conditions;
101        }
102
103        lifted
104    }
105
106    /// Apply lifted expressions to SELECT items
107    pub fn apply_lifted_to_select(
108        &self,
109        stmt: &mut SelectStatement,
110        lifted: &[LiftedInExpression],
111    ) {
112        // Add the computed expressions to the SELECT list
113        for lift in lifted {
114            stmt.select_items.push(SelectItem::Expression {
115                expr: lift.original_expr.clone(),
116                alias: lift.alias.clone(),
117            });
118        }
119    }
120
121    /// Create a CTE that includes the lifted expressions
122    pub fn create_lifting_cte(
123        &self,
124        base_table: &str,
125        lifted: &[LiftedInExpression],
126        cte_name: String,
127    ) -> CTE {
128        let mut select_items = vec![SelectItem::Star];
129
130        // Add each lifted expression as a computed column
131        for lift in lifted {
132            select_items.push(SelectItem::Expression {
133                expr: lift.original_expr.clone(),
134                alias: lift.alias.clone(),
135            });
136        }
137
138        let cte_select = SelectStatement {
139            distinct: false,
140            columns: vec!["*".to_string()],
141            select_items,
142            from_table: Some(base_table.to_string()),
143            from_subquery: None,
144            from_function: None,
145            from_alias: None,
146            joins: Vec::new(),
147            where_clause: None,
148            order_by: None,
149            group_by: None,
150            having: None,
151            limit: None,
152            offset: None,
153            ctes: Vec::new(),
154            into_table: None,
155        };
156
157        CTE {
158            name: cte_name,
159            column_list: None,
160            cte_type: CTEType::Standard(cte_select),
161        }
162    }
163
164    /// Rewrite a query to lift IN expressions with function calls
165    pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
166        // Check if we have any IN expressions to lift
167        let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
168            where_clause
169                .conditions
170                .iter()
171                .any(|c| Self::needs_in_lifting(&c.expr))
172        } else {
173            false
174        };
175
176        if !has_in_to_lift {
177            return false;
178        }
179
180        // Extract the base table name
181        let base_table = match &stmt.from_table {
182            Some(table) => table.clone(),
183            None => return false, // Can't lift without a FROM clause
184        };
185
186        // Lift the IN expressions
187        let lifted = self.lift_in_expressions(stmt);
188
189        if lifted.is_empty() {
190            return false;
191        }
192
193        // Create a CTE with the lifted expressions
194        let cte_name = format!("{}_lifted", base_table);
195        let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
196
197        // Add the CTE to the statement
198        stmt.ctes.push(cte);
199
200        // Update the FROM clause to use the CTE
201        stmt.from_table = Some(cte_name);
202
203        true
204    }
205}
206
207/// Information about a lifted IN expression
208#[derive(Debug, Clone)]
209pub struct LiftedInExpression {
210    /// The original expression (e.g., LOWER(column))
211    pub original_expr: SqlExpression,
212    /// The alias for the computed column
213    pub alias: String,
214    /// The values in the IN list
215    pub values: Vec<SqlExpression>,
216    /// Whether this was NOT IN (vs IN)
217    pub is_not_in: bool,
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unquoted column reference expression.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Build the function call `LOWER(<arg>)`.
    fn lower(arg: SqlExpression) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![arg],
            distinct: false,
        }
    }

    /// A one-element IN value list: `('a')`.
    fn sample_values() -> Vec<SqlExpression> {
        vec![SqlExpression::StringLiteral("a".to_string())]
    }

    #[test]
    fn test_needs_in_lifting() {
        // Simple column IN doesn't need lifting
        let simple_in = SqlExpression::InList {
            expr: Box::new(col("col")),
            values: sample_values(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_in));

        // Function call IN needs lifting
        let func_in = SqlExpression::InList {
            expr: Box::new(lower(col("col"))),
            values: sample_values(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_in));

        // NOT IN follows the same rule as IN
        let simple_not_in = SqlExpression::NotInList {
            expr: Box::new(col("col")),
            values: sample_values(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_not_in));

        let func_not_in = SqlExpression::NotInList {
            expr: Box::new(lower(col("col"))),
            values: sample_values(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_not_in));

        // Expressions that are not IN lists never need lifting, even when complex
        assert!(!InOperatorLifter::needs_in_lifting(&lower(col("col"))));
        assert!(!InOperatorLifter::needs_in_lifting(&col("col")));
    }

    #[test]
    fn test_is_complex_expression() {
        // Column is not complex
        assert!(!InOperatorLifter::is_complex_expression(&col("col")));

        // Function call is complex
        assert!(InOperatorLifter::is_complex_expression(&lower(col("col"))));

        // Binary op is complex
        assert!(InOperatorLifter::is_complex_expression(
            &SqlExpression::BinaryOp {
                left: Box::new(col("a")),
                op: "+".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
            }
        ));
    }
}