sql_cli/query_plan/in_operator_lifter.rs

1use crate::sql::parser::ast::{
2    CTEType, ColumnRef, Condition, SelectItem, SelectStatement, SqlExpression, CTE,
3};
4
/// Rewrites `[NOT] IN` conditions whose left-hand side is a computed
/// expression (e.g. `LOWER(col) IN (...)`) so that the expression is
/// materialised as a named column and the condition compares against it.
///
/// `InOperatorLifter::default()` is equivalent to `InOperatorLifter::new()`.
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of aliases handed out so far; the next generated alias is
    /// `__expr_<column_counter + 1>`.
    column_counter: usize,
}
10
11impl InOperatorLifter {
12    pub fn new() -> Self {
13        InOperatorLifter { column_counter: 0 }
14    }
15
16    /// Generate a unique column name for lifted expressions
17    fn next_column_name(&mut self) -> String {
18        self.column_counter += 1;
19        format!("__expr_{}", self.column_counter)
20    }
21
22    /// Check if an IN expression needs lifting (has function call on left side)
23    pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24        match expr {
25            SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26                Self::is_complex_expression(expr)
27            }
28            _ => false,
29        }
30    }
31
32    /// Check if an expression is complex (not just a column reference)
33    fn is_complex_expression(expr: &SqlExpression) -> bool {
34        !matches!(expr, SqlExpression::Column(_))
35    }
36
37    /// Lift IN expressions with function calls
38    pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39        let mut lifted = Vec::new();
40
41        if let Some(ref mut where_clause) = stmt.where_clause {
42            let mut new_conditions = Vec::new();
43
44            for condition in &where_clause.conditions {
45                match &condition.expr {
46                    SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47                        let column_alias = self.next_column_name();
48
49                        // Record what we're lifting
50                        lifted.push(LiftedInExpression {
51                            original_expr: expr.as_ref().clone(),
52                            alias: column_alias.clone(),
53                            values: values.clone(),
54                            is_not_in: false,
55                        });
56
57                        // Create new simple condition
58                        new_conditions.push(Condition {
59                            expr: SqlExpression::InList {
60                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
61                                    column_alias,
62                                ))),
63                                values: values.clone(),
64                            },
65                            connector: condition.connector.clone(),
66                        });
67                    }
68                    SqlExpression::NotInList { expr, values }
69                        if Self::is_complex_expression(expr) =>
70                    {
71                        let column_alias = self.next_column_name();
72
73                        // Record what we're lifting
74                        lifted.push(LiftedInExpression {
75                            original_expr: expr.as_ref().clone(),
76                            alias: column_alias.clone(),
77                            values: values.clone(),
78                            is_not_in: true,
79                        });
80
81                        // Create new simple condition
82                        new_conditions.push(Condition {
83                            expr: SqlExpression::NotInList {
84                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
85                                    column_alias,
86                                ))),
87                                values: values.clone(),
88                            },
89                            connector: condition.connector.clone(),
90                        });
91                    }
92                    _ => {
93                        // Keep condition as-is
94                        new_conditions.push(condition.clone());
95                    }
96                }
97            }
98
99            // Update WHERE clause with simplified conditions
100            where_clause.conditions = new_conditions;
101        }
102
103        lifted
104    }
105
106    /// Apply lifted expressions to SELECT items
107    pub fn apply_lifted_to_select(
108        &self,
109        stmt: &mut SelectStatement,
110        lifted: &[LiftedInExpression],
111    ) {
112        // Add the computed expressions to the SELECT list
113        for lift in lifted {
114            stmt.select_items.push(SelectItem::Expression {
115                expr: lift.original_expr.clone(),
116                alias: lift.alias.clone(),
117            });
118        }
119    }
120
121    /// Create a CTE that includes the lifted expressions
122    pub fn create_lifting_cte(
123        &self,
124        base_table: &str,
125        lifted: &[LiftedInExpression],
126        cte_name: String,
127    ) -> CTE {
128        let mut select_items = vec![SelectItem::Star];
129
130        // Add each lifted expression as a computed column
131        for lift in lifted {
132            select_items.push(SelectItem::Expression {
133                expr: lift.original_expr.clone(),
134                alias: lift.alias.clone(),
135            });
136        }
137
138        let cte_select = SelectStatement {
139            distinct: false,
140            columns: vec!["*".to_string()],
141            select_items,
142            from_table: Some(base_table.to_string()),
143            from_subquery: None,
144            from_function: None,
145            from_alias: None,
146            joins: Vec::new(),
147            where_clause: None,
148            order_by: None,
149            group_by: None,
150            having: None,
151            limit: None,
152            offset: None,
153            ctes: Vec::new(),
154        };
155
156        CTE {
157            name: cte_name,
158            column_list: None,
159            cte_type: CTEType::Standard(cte_select),
160        }
161    }
162
163    /// Rewrite a query to lift IN expressions with function calls
164    pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
165        // Check if we have any IN expressions to lift
166        let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
167            where_clause
168                .conditions
169                .iter()
170                .any(|c| Self::needs_in_lifting(&c.expr))
171        } else {
172            false
173        };
174
175        if !has_in_to_lift {
176            return false;
177        }
178
179        // Extract the base table name
180        let base_table = match &stmt.from_table {
181            Some(table) => table.clone(),
182            None => return false, // Can't lift without a FROM clause
183        };
184
185        // Lift the IN expressions
186        let lifted = self.lift_in_expressions(stmt);
187
188        if lifted.is_empty() {
189            return false;
190        }
191
192        // Create a CTE with the lifted expressions
193        let cte_name = format!("{}_lifted", base_table);
194        let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
195
196        // Add the CTE to the statement
197        stmt.ctes.push(cte);
198
199        // Update the FROM clause to use the CTE
200        stmt.from_table = Some(cte_name);
201
202        true
203    }
204}
205
/// Record of one `[NOT] IN` condition rewritten by
/// [`InOperatorLifter::lift_in_expressions`].
///
/// The condition now reads `<alias> [NOT] IN (<values>)`. The caller must make
/// `alias` resolve to `original_expr`, e.g. as a computed SELECT item or a
/// column of a lifting CTE.
#[derive(Debug, Clone)]
pub struct LiftedInExpression {
    /// The left-hand side that was lifted out of the condition (e.g. `LOWER(column)`).
    pub original_expr: SqlExpression,
    /// Generated column name (`__expr_<n>`) that replaced `original_expr` in the condition.
    pub alias: String,
    /// The IN-list values, copied unchanged from the original condition.
    pub values: Vec<SqlExpression>,
    /// `true` if the original condition was `NOT IN`, `false` for `IN`.
    pub is_not_in: bool,
}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Builds `LOWER(<name>)` over a single column.
    fn lower_of(name: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![column(name)],
            distinct: false,
        }
    }

    /// Wraps `target` as the left-hand side of `target IN ('a')`.
    fn in_list_of(target: SqlExpression) -> SqlExpression {
        SqlExpression::InList {
            expr: Box::new(target),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        }
    }

    #[test]
    fn test_needs_in_lifting() {
        // A bare column on the left-hand side can be evaluated as-is.
        let plain = in_list_of(column("col"));
        assert!(!InOperatorLifter::needs_in_lifting(&plain));

        // A function call on the left-hand side must be lifted first.
        let computed = in_list_of(lower_of("col"));
        assert!(InOperatorLifter::needs_in_lifting(&computed));
    }

    #[test]
    fn test_is_complex_expression() {
        let addition = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "+".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
        };

        // (expression, expected "is complex" verdict)
        let cases = [
            (column("col"), false),
            (lower_of("col"), true),
            (addition, true),
        ];

        for (expr, expected) in &cases {
            assert_eq!(InOperatorLifter::is_complex_expression(expr), *expected);
        }
    }
}
275}