sql_cli/query_plan/
in_operator_lifter.rs

1use crate::sql::parser::ast::{
2    CTEType, ColumnRef, Condition, SelectItem, SelectStatement, SqlExpression, CTE,
3};
4
/// Rewrites `IN` / `NOT IN` conditions whose left-hand operand is a computed
/// expression (for example `LOWER(col) IN (...)`) so that the operand becomes
/// a named column which the condition can reference directly.
///
/// `Default` yields the same state as `InOperatorLifter::new()`: a counter of
/// zero, so the first generated alias is `__expr_1`.
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of aliases handed out so far; incremented before each alias is
    /// formatted, which keeps generated names unique per lifter instance.
    column_counter: usize,
}
10
11impl InOperatorLifter {
12    pub fn new() -> Self {
13        InOperatorLifter { column_counter: 0 }
14    }
15
16    /// Generate a unique column name for lifted expressions
17    fn next_column_name(&mut self) -> String {
18        self.column_counter += 1;
19        format!("__expr_{}", self.column_counter)
20    }
21
22    /// Check if an IN expression needs lifting (has function call on left side)
23    pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24        match expr {
25            SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26                Self::is_complex_expression(expr)
27            }
28            _ => false,
29        }
30    }
31
32    /// Check if an expression is complex (not just a column reference)
33    fn is_complex_expression(expr: &SqlExpression) -> bool {
34        !matches!(expr, SqlExpression::Column(_))
35    }
36
37    /// Lift IN expressions with function calls
38    pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39        let mut lifted = Vec::new();
40
41        if let Some(ref mut where_clause) = stmt.where_clause {
42            let mut new_conditions = Vec::new();
43
44            for condition in &where_clause.conditions {
45                match &condition.expr {
46                    SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47                        let column_alias = self.next_column_name();
48
49                        // Record what we're lifting
50                        lifted.push(LiftedInExpression {
51                            original_expr: expr.as_ref().clone(),
52                            alias: column_alias.clone(),
53                            values: values.clone(),
54                            is_not_in: false,
55                        });
56
57                        // Create new simple condition
58                        new_conditions.push(Condition {
59                            expr: SqlExpression::InList {
60                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
61                                    column_alias,
62                                ))),
63                                values: values.clone(),
64                            },
65                            connector: condition.connector.clone(),
66                        });
67                    }
68                    SqlExpression::NotInList { expr, values }
69                        if Self::is_complex_expression(expr) =>
70                    {
71                        let column_alias = self.next_column_name();
72
73                        // Record what we're lifting
74                        lifted.push(LiftedInExpression {
75                            original_expr: expr.as_ref().clone(),
76                            alias: column_alias.clone(),
77                            values: values.clone(),
78                            is_not_in: true,
79                        });
80
81                        // Create new simple condition
82                        new_conditions.push(Condition {
83                            expr: SqlExpression::NotInList {
84                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
85                                    column_alias,
86                                ))),
87                                values: values.clone(),
88                            },
89                            connector: condition.connector.clone(),
90                        });
91                    }
92                    _ => {
93                        // Keep condition as-is
94                        new_conditions.push(condition.clone());
95                    }
96                }
97            }
98
99            // Update WHERE clause with simplified conditions
100            where_clause.conditions = new_conditions;
101        }
102
103        lifted
104    }
105
106    /// Apply lifted expressions to SELECT items
107    pub fn apply_lifted_to_select(
108        &self,
109        stmt: &mut SelectStatement,
110        lifted: &[LiftedInExpression],
111    ) {
112        // Add the computed expressions to the SELECT list
113        for lift in lifted {
114            stmt.select_items.push(SelectItem::Expression {
115                expr: lift.original_expr.clone(),
116                alias: lift.alias.clone(),
117                leading_comments: vec![],
118                trailing_comment: None,
119            });
120        }
121    }
122
123    /// Create a CTE that includes the lifted expressions
124    pub fn create_lifting_cte(
125        &self,
126        base_table: &str,
127        lifted: &[LiftedInExpression],
128        cte_name: String,
129    ) -> CTE {
130        let mut select_items = vec![SelectItem::Star {
131            leading_comments: vec![],
132            trailing_comment: None,
133        }];
134
135        // Add each lifted expression as a computed column
136        for lift in lifted {
137            select_items.push(SelectItem::Expression {
138                expr: lift.original_expr.clone(),
139                alias: lift.alias.clone(),
140                leading_comments: vec![],
141                trailing_comment: None,
142            });
143        }
144
145        let cte_select = SelectStatement {
146            distinct: false,
147            columns: vec!["*".to_string()],
148            select_items,
149            from_table: Some(base_table.to_string()),
150            from_subquery: None,
151            from_function: None,
152            from_alias: None,
153            joins: Vec::new(),
154            where_clause: None,
155            order_by: None,
156            group_by: None,
157            having: None,
158            limit: None,
159            offset: None,
160            ctes: Vec::new(),
161            into_table: None,
162            set_operations: Vec::new(),
163            leading_comments: vec![],
164            trailing_comment: None,
165        };
166
167        CTE {
168            name: cte_name,
169            column_list: None,
170            cte_type: CTEType::Standard(cte_select),
171        }
172    }
173
174    /// Rewrite a query to lift IN expressions with function calls
175    pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
176        // Check if we have any IN expressions to lift
177        let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
178            where_clause
179                .conditions
180                .iter()
181                .any(|c| Self::needs_in_lifting(&c.expr))
182        } else {
183            false
184        };
185
186        if !has_in_to_lift {
187            return false;
188        }
189
190        // Extract the base table name
191        let base_table = match &stmt.from_table {
192            Some(table) => table.clone(),
193            None => return false, // Can't lift without a FROM clause
194        };
195
196        // Lift the IN expressions
197        let lifted = self.lift_in_expressions(stmt);
198
199        if lifted.is_empty() {
200            return false;
201        }
202
203        // Create a CTE with the lifted expressions
204        let cte_name = format!("{}_lifted", base_table);
205        let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
206
207        // Add the CTE to the statement
208        stmt.ctes.push(cte);
209
210        // Update the FROM clause to use the CTE
211        stmt.from_table = Some(cte_name);
212
213        true
214    }
215}
216
/// Record of one `IN` / `NOT IN` condition whose left-hand operand was
/// replaced by a generated column alias.
///
/// Produced by `InOperatorLifter::lift_in_expressions`; the `original_expr`
/// and `alias` fields are what `create_lifting_cte` and
/// `apply_lifted_to_select` turn into `<original_expr> AS <alias>` items.
#[derive(Debug, Clone)]
pub struct LiftedInExpression {
    /// The operand that was lifted out of the condition (e.g., LOWER(column))
    pub original_expr: SqlExpression,
    /// The generated column name that now stands in for `original_expr`
    pub alias: String,
    /// Copy of the values from the condition's IN list
    pub values: Vec<SqlExpression>,
    /// `true` if the condition was NOT IN, `false` if it was IN
    pub is_not_in: bool,
}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Builds the function call `LOWER(col)`.
    fn lower_of_col() -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![column("col")],
            distinct: false,
        }
    }

    /// Builds the single-element value list `('a')`.
    fn single_string_value() -> Vec<SqlExpression> {
        vec![SqlExpression::StringLiteral("a".to_string())]
    }

    #[test]
    fn test_needs_in_lifting() {
        // Simple column IN doesn't need lifting
        let simple_in = SqlExpression::InList {
            expr: Box::new(column("col")),
            values: single_string_value(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_in));

        // Function call IN needs lifting
        let func_in = SqlExpression::InList {
            expr: Box::new(lower_of_col()),
            values: single_string_value(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_in));
    }

    #[test]
    fn test_needs_in_lifting_not_in() {
        // Simple column NOT IN doesn't need lifting
        let simple_not_in = SqlExpression::NotInList {
            expr: Box::new(column("col")),
            values: single_string_value(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_not_in));

        // Function call NOT IN needs lifting
        let func_not_in = SqlExpression::NotInList {
            expr: Box::new(lower_of_col()),
            values: single_string_value(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_not_in));
    }

    #[test]
    fn test_needs_in_lifting_ignores_non_in_expressions() {
        // Expressions that are not IN / NOT IN never need lifting, even when
        // they would count as complex operands themselves.
        assert!(!InOperatorLifter::needs_in_lifting(&column("col")));
        assert!(!InOperatorLifter::needs_in_lifting(&lower_of_col()));
    }

    #[test]
    fn test_is_complex_expression() {
        // Column is not complex
        assert!(!InOperatorLifter::is_complex_expression(&column("col")));

        // Function call is complex
        assert!(InOperatorLifter::is_complex_expression(&lower_of_col()));

        // Binary op is complex
        let binary_op = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "+".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
        };
        assert!(InOperatorLifter::is_complex_expression(&binary_op));
    }
}