// sql_cli/query_plan/in_operator_lifter.rs

use crate::sql::parser::ast::{
    CTEType, ColumnRef, Condition, SelectItem, SelectStatement, SqlExpression, TableSource, CTE,
};

/// Specialized lifter for `IN` / `NOT IN` predicates whose left-hand side is not a plain
/// column reference (for example `LOWER(name) IN ('a', 'b')`).
///
/// `Default` is equivalent to `InOperatorLifter::new()`: the alias counter starts at zero.
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of aliases handed out so far; used to build unique `__expr_N` column names.
    column_counter: usize,
}

11impl InOperatorLifter {
12    pub fn new() -> Self {
13        InOperatorLifter { column_counter: 0 }
14    }
15
16    /// Generate a unique column name for lifted expressions
17    fn next_column_name(&mut self) -> String {
18        self.column_counter += 1;
19        format!("__expr_{}", self.column_counter)
20    }
21
22    /// Check if an IN expression needs lifting (has function call on left side)
23    pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24        match expr {
25            SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26                Self::is_complex_expression(expr)
27            }
28            _ => false,
29        }
30    }
31
32    /// Check if an expression is complex (not just a column reference)
33    fn is_complex_expression(expr: &SqlExpression) -> bool {
34        !matches!(expr, SqlExpression::Column(_))
35    }
36
37    /// Lift IN expressions with function calls
38    pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39        let mut lifted = Vec::new();
40
41        if let Some(ref mut where_clause) = stmt.where_clause {
42            let mut new_conditions = Vec::new();
43
44            for condition in &where_clause.conditions {
45                match &condition.expr {
46                    SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47                        let column_alias = self.next_column_name();
48
49                        // Record what we're lifting
50                        lifted.push(LiftedInExpression {
51                            original_expr: expr.as_ref().clone(),
52                            alias: column_alias.clone(),
53                            values: values.clone(),
54                            is_not_in: false,
55                        });
56
57                        // Create new simple condition
58                        new_conditions.push(Condition {
59                            expr: SqlExpression::InList {
60                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
61                                    column_alias,
62                                ))),
63                                values: values.clone(),
64                            },
65                            connector: condition.connector.clone(),
66                        });
67                    }
68                    SqlExpression::NotInList { expr, values }
69                        if Self::is_complex_expression(expr) =>
70                    {
71                        let column_alias = self.next_column_name();
72
73                        // Record what we're lifting
74                        lifted.push(LiftedInExpression {
75                            original_expr: expr.as_ref().clone(),
76                            alias: column_alias.clone(),
77                            values: values.clone(),
78                            is_not_in: true,
79                        });
80
81                        // Create new simple condition
82                        new_conditions.push(Condition {
83                            expr: SqlExpression::NotInList {
84                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
85                                    column_alias,
86                                ))),
87                                values: values.clone(),
88                            },
89                            connector: condition.connector.clone(),
90                        });
91                    }
92                    _ => {
93                        // Keep condition as-is
94                        new_conditions.push(condition.clone());
95                    }
96                }
97            }
98
99            // Update WHERE clause with simplified conditions
100            where_clause.conditions = new_conditions;
101        }
102
103        lifted
104    }
105
106    /// Apply lifted expressions to SELECT items
107    pub fn apply_lifted_to_select(
108        &self,
109        stmt: &mut SelectStatement,
110        lifted: &[LiftedInExpression],
111    ) {
112        // Add the computed expressions to the SELECT list
113        for lift in lifted {
114            stmt.select_items.push(SelectItem::Expression {
115                expr: lift.original_expr.clone(),
116                alias: lift.alias.clone(),
117                leading_comments: vec![],
118                trailing_comment: None,
119            });
120        }
121    }
122
123    /// Create a CTE that includes the lifted expressions
124    pub fn create_lifting_cte(
125        &self,
126        base_table: &str,
127        lifted: &[LiftedInExpression],
128        cte_name: String,
129    ) -> CTE {
130        let mut select_items = vec![SelectItem::Star {
131            table_prefix: None,
132            leading_comments: vec![],
133            trailing_comment: None,
134        }];
135
136        // Add each lifted expression as a computed column
137        for lift in lifted {
138            select_items.push(SelectItem::Expression {
139                expr: lift.original_expr.clone(),
140                alias: lift.alias.clone(),
141                leading_comments: vec![],
142                trailing_comment: None,
143            });
144        }
145
146        let cte_select = SelectStatement {
147            distinct: false,
148            columns: vec!["*".to_string()],
149            select_items,
150            from_source: Some(TableSource::Table(base_table.to_string())),
151            #[allow(deprecated)]
152            from_table: Some(base_table.to_string()),
153            #[allow(deprecated)]
154            from_subquery: None,
155            #[allow(deprecated)]
156            from_function: None,
157            #[allow(deprecated)]
158            from_alias: None,
159            joins: Vec::new(),
160            where_clause: None,
161            order_by: None,
162            group_by: None,
163            having: None,
164            limit: None,
165            offset: None,
166            ctes: Vec::new(),
167            into_table: None,
168            set_operations: Vec::new(),
169            leading_comments: vec![],
170            trailing_comment: None,
171            qualify: None,
172        };
173
174        CTE {
175            name: cte_name,
176            column_list: None,
177            cte_type: CTEType::Standard(cte_select),
178        }
179    }
180
181    /// Rewrite a query to lift IN expressions with function calls
182    pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
183        // Check if we have any IN expressions to lift
184        let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
185            where_clause
186                .conditions
187                .iter()
188                .any(|c| Self::needs_in_lifting(&c.expr))
189        } else {
190            false
191        };
192
193        if !has_in_to_lift {
194            return false;
195        }
196
197        // Extract the base table name
198        let base_table = match &stmt.from_table {
199            Some(table) => table.clone(),
200            None => return false, // Can't lift without a FROM clause
201        };
202
203        // Lift the IN expressions
204        let lifted = self.lift_in_expressions(stmt);
205
206        if lifted.is_empty() {
207            return false;
208        }
209
210        // Create a CTE with the lifted expressions
211        let cte_name = format!("{}_lifted", base_table);
212        let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
213
214        // Add the CTE to the statement
215        stmt.ctes.push(cte);
216
217        // Update the FROM clause to use the CTE
218        stmt.from_table = Some(cte_name);
219
220        true
221    }
222}
223
224/// Information about a lifted IN expression
225#[derive(Debug, Clone)]
226pub struct LiftedInExpression {
227    /// The original expression (e.g., LOWER(column))
228    pub original_expr: SqlExpression,
229    /// The alias for the computed column
230    pub alias: String,
231    /// The values in the IN list
232    pub values: Vec<SqlExpression>,
233    /// Whether this was NOT IN (vs IN)
234    pub is_not_in: bool,
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `LOWER(<column>)`, the canonical left-hand side that needs lifting.
    fn lower(column: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![SqlExpression::Column(ColumnRef::unquoted(
                column.to_string(),
            ))],
            distinct: false,
        }
    }

    /// Returns `SELECT * FROM <table>` (no WHERE clause).
    ///
    /// The statement is taken from the CTE body built by `create_lifting_cte`. This avoids
    /// repeating the full `SelectStatement` field list in the tests.
    fn select_from(table: &str) -> SelectStatement {
        match InOperatorLifter::new()
            .create_lifting_cte(table, &[], "unused".to_string())
            .cte_type
        {
            CTEType::Standard(select) => select,
            _ => unreachable!("create_lifting_cte always builds a standard CTE"),
        }
    }

    #[test]
    fn test_needs_in_lifting() {
        // Simple column IN doesn't need lifting
        let simple_in = SqlExpression::InList {
            expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
                "col".to_string(),
            ))),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_in));

        // Function call IN needs lifting
        let func_in = SqlExpression::InList {
            expr: Box::new(lower("col")),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_in));
    }

    #[test]
    fn test_needs_not_in_lifting() {
        // Simple column NOT IN doesn't need lifting
        let simple_not_in = SqlExpression::NotInList {
            expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
                "col".to_string(),
            ))),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_not_in));

        // Function call NOT IN needs lifting
        let func_not_in = SqlExpression::NotInList {
            expr: Box::new(lower("col")),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_not_in));

        // Non-IN expressions never need IN lifting, however complex they are
        assert!(!InOperatorLifter::needs_in_lifting(&lower("col")));
    }

    #[test]
    fn test_is_complex_expression() {
        // Column is not complex
        assert!(!InOperatorLifter::is_complex_expression(
            &SqlExpression::Column(ColumnRef::unquoted("col".to_string()))
        ));

        // Function call is complex
        assert!(InOperatorLifter::is_complex_expression(&lower("col")));

        // Binary op is complex
        assert!(InOperatorLifter::is_complex_expression(
            &SqlExpression::BinaryOp {
                left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
                op: "+".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
            }
        ));
    }

    #[test]
    fn test_next_column_name_is_unique() {
        let mut lifter = InOperatorLifter::new();
        assert_eq!(lifter.next_column_name(), "__expr_1");
        assert_eq!(lifter.next_column_name(), "__expr_2");
    }

    #[test]
    // Checks that the legacy `from_table` field is still populated on the CTE body.
    #[allow(deprecated)]
    fn test_create_lifting_cte() {
        let lifter = InOperatorLifter::new();
        let lifted = vec![LiftedInExpression {
            original_expr: lower("name"),
            alias: "__expr_1".to_string(),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
            is_not_in: false,
        }];

        let cte = lifter.create_lifting_cte("users", &lifted, "users_lifted".to_string());
        assert_eq!(cte.name, "users_lifted");
        assert!(cte.column_list.is_none());

        let select = match &cte.cte_type {
            CTEType::Standard(select) => select,
            _ => panic!("expected a standard CTE"),
        };
        // SELECT *, LOWER(name) AS __expr_1 FROM users
        assert_eq!(select.select_items.len(), 2);
        assert!(matches!(
            select.select_items[0],
            SelectItem::Star { table_prefix: None, .. }
        ));
        assert!(matches!(
            &select.select_items[1],
            SelectItem::Expression { alias, .. } if alias == "__expr_1"
        ));
        assert!(matches!(
            &select.from_source,
            Some(TableSource::Table(table)) if table == "users"
        ));
        assert_eq!(select.from_table.as_deref(), Some("users"));
        assert!(select.where_clause.is_none());
    }

    #[test]
    fn test_rewrite_without_where_is_noop() {
        let mut lifter = InOperatorLifter::new();
        let mut stmt = select_from("users");

        assert!(!lifter.rewrite_query(&mut stmt));
        assert!(stmt.ctes.is_empty());
        assert!(matches!(
            &stmt.from_source,
            Some(TableSource::Table(table)) if table == "users"
        ));
        assert!(lifter.lift_in_expressions(&mut stmt).is_empty());
    }
}
293}