// sql_cli/query_plan/in_operator_lifter.rs

1use crate::sql::parser::ast::{
2    CTEType, Condition, SelectItem, SelectStatement, SqlExpression, CTE,
3};
4
/// Rewrites `IN` / `NOT IN` conditions whose left-hand side is not a plain column so that
/// they compare against a precomputed, aliased column instead.
///
/// `Default` yields the same initial state as `InOperatorLifter::new()` (counter at zero).
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of aliases handed out so far; used to build unique `__expr_N` column names.
    column_counter: usize,
}
10
11impl InOperatorLifter {
12    pub fn new() -> Self {
13        InOperatorLifter { column_counter: 0 }
14    }
15
16    /// Generate a unique column name for lifted expressions
17    fn next_column_name(&mut self) -> String {
18        self.column_counter += 1;
19        format!("__expr_{}", self.column_counter)
20    }
21
22    /// Check if an IN expression needs lifting (has function call on left side)
23    pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24        match expr {
25            SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26                Self::is_complex_expression(expr)
27            }
28            _ => false,
29        }
30    }
31
32    /// Check if an expression is complex (not just a column reference)
33    fn is_complex_expression(expr: &SqlExpression) -> bool {
34        !matches!(expr, SqlExpression::Column(_))
35    }
36
37    /// Lift IN expressions with function calls
38    pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39        let mut lifted = Vec::new();
40
41        if let Some(ref mut where_clause) = stmt.where_clause {
42            let mut new_conditions = Vec::new();
43
44            for condition in &where_clause.conditions {
45                match &condition.expr {
46                    SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47                        let column_alias = self.next_column_name();
48
49                        // Record what we're lifting
50                        lifted.push(LiftedInExpression {
51                            original_expr: expr.as_ref().clone(),
52                            alias: column_alias.clone(),
53                            values: values.clone(),
54                            is_not_in: false,
55                        });
56
57                        // Create new simple condition
58                        new_conditions.push(Condition {
59                            expr: SqlExpression::InList {
60                                expr: Box::new(SqlExpression::Column(column_alias)),
61                                values: values.clone(),
62                            },
63                            connector: condition.connector.clone(),
64                        });
65                    }
66                    SqlExpression::NotInList { expr, values }
67                        if Self::is_complex_expression(expr) =>
68                    {
69                        let column_alias = self.next_column_name();
70
71                        // Record what we're lifting
72                        lifted.push(LiftedInExpression {
73                            original_expr: expr.as_ref().clone(),
74                            alias: column_alias.clone(),
75                            values: values.clone(),
76                            is_not_in: true,
77                        });
78
79                        // Create new simple condition
80                        new_conditions.push(Condition {
81                            expr: SqlExpression::NotInList {
82                                expr: Box::new(SqlExpression::Column(column_alias)),
83                                values: values.clone(),
84                            },
85                            connector: condition.connector.clone(),
86                        });
87                    }
88                    _ => {
89                        // Keep condition as-is
90                        new_conditions.push(condition.clone());
91                    }
92                }
93            }
94
95            // Update WHERE clause with simplified conditions
96            where_clause.conditions = new_conditions;
97        }
98
99        lifted
100    }
101
102    /// Apply lifted expressions to SELECT items
103    pub fn apply_lifted_to_select(
104        &self,
105        stmt: &mut SelectStatement,
106        lifted: &[LiftedInExpression],
107    ) {
108        // Add the computed expressions to the SELECT list
109        for lift in lifted {
110            stmt.select_items.push(SelectItem::Expression {
111                expr: lift.original_expr.clone(),
112                alias: lift.alias.clone(),
113            });
114        }
115    }
116
117    /// Create a CTE that includes the lifted expressions
118    pub fn create_lifting_cte(
119        &self,
120        base_table: &str,
121        lifted: &[LiftedInExpression],
122        cte_name: String,
123    ) -> CTE {
124        let mut select_items = vec![SelectItem::Star];
125
126        // Add each lifted expression as a computed column
127        for lift in lifted {
128            select_items.push(SelectItem::Expression {
129                expr: lift.original_expr.clone(),
130                alias: lift.alias.clone(),
131            });
132        }
133
134        let cte_select = SelectStatement {
135            distinct: false,
136            columns: vec!["*".to_string()],
137            select_items,
138            from_table: Some(base_table.to_string()),
139            from_subquery: None,
140            from_function: None,
141            from_alias: None,
142            joins: Vec::new(),
143            where_clause: None,
144            order_by: None,
145            group_by: None,
146            having: None,
147            limit: None,
148            offset: None,
149            ctes: Vec::new(),
150        };
151
152        CTE {
153            name: cte_name,
154            column_list: None,
155            cte_type: CTEType::Standard(cte_select),
156        }
157    }
158
159    /// Rewrite a query to lift IN expressions with function calls
160    pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
161        // Check if we have any IN expressions to lift
162        let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
163            where_clause
164                .conditions
165                .iter()
166                .any(|c| Self::needs_in_lifting(&c.expr))
167        } else {
168            false
169        };
170
171        if !has_in_to_lift {
172            return false;
173        }
174
175        // Extract the base table name
176        let base_table = match &stmt.from_table {
177            Some(table) => table.clone(),
178            None => return false, // Can't lift without a FROM clause
179        };
180
181        // Lift the IN expressions
182        let lifted = self.lift_in_expressions(stmt);
183
184        if lifted.is_empty() {
185            return false;
186        }
187
188        // Create a CTE with the lifted expressions
189        let cte_name = format!("{}_lifted", base_table);
190        let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
191
192        // Add the CTE to the statement
193        stmt.ctes.push(cte);
194
195        // Update the FROM clause to use the CTE
196        stmt.from_table = Some(cte_name);
197
198        true
199    }
200}
201
202/// Information about a lifted IN expression
203#[derive(Debug, Clone)]
204pub struct LiftedInExpression {
205    /// The original expression (e.g., LOWER(column))
206    pub original_expr: SqlExpression,
207    /// The alias for the computed column
208    pub alias: String,
209    /// The values in the IN list
210    pub values: Vec<SqlExpression>,
211    /// Whether this was NOT IN (vs IN)
212    pub is_not_in: bool,
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `LOWER(<column>)` for use as a complex IN left-hand side.
    fn lower_of(column: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![SqlExpression::Column(column.to_string())],
            distinct: false,
        }
    }

    fn string_values() -> Vec<SqlExpression> {
        vec![SqlExpression::StringLiteral("a".to_string())]
    }

    #[test]
    fn test_needs_in_lifting() {
        // Simple column IN doesn't need lifting
        let simple_in = SqlExpression::InList {
            expr: Box::new(SqlExpression::Column("col".to_string())),
            values: string_values(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_in));

        // Function call IN needs lifting
        let func_in = SqlExpression::InList {
            expr: Box::new(lower_of("col")),
            values: string_values(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_in));

        // NOT IN follows the same rules as IN
        let simple_not_in = SqlExpression::NotInList {
            expr: Box::new(SqlExpression::Column("col".to_string())),
            values: string_values(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_not_in));

        let func_not_in = SqlExpression::NotInList {
            expr: Box::new(lower_of("col")),
            values: string_values(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_not_in));

        // Expressions that are not IN / NOT IN never need lifting
        assert!(!InOperatorLifter::needs_in_lifting(&lower_of("col")));
    }

    #[test]
    fn test_is_complex_expression() {
        // Column is not complex
        assert!(!InOperatorLifter::is_complex_expression(
            &SqlExpression::Column("col".to_string())
        ));

        // Function call is complex
        assert!(InOperatorLifter::is_complex_expression(&lower_of("col")));

        // Binary op is complex
        assert!(InOperatorLifter::is_complex_expression(
            &SqlExpression::BinaryOp {
                left: Box::new(SqlExpression::Column("a".to_string())),
                op: "+".to_string(),
                right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
            }
        ));
    }

    #[test]
    fn test_column_names_are_unique_and_sequential() {
        let mut lifter = InOperatorLifter::new();
        assert_eq!(lifter.next_column_name(), "__expr_1");
        assert_eq!(lifter.next_column_name(), "__expr_2");
    }

    #[test]
    fn test_create_lifting_cte() {
        let lifter = InOperatorLifter::new();
        let lifted = vec![LiftedInExpression {
            original_expr: lower_of("name"),
            alias: "__expr_1".to_string(),
            values: string_values(),
            is_not_in: false,
        }];

        let cte = lifter.create_lifting_cte("employees", &lifted, "employees_lifted".to_string());
        assert_eq!(cte.name, "employees_lifted");
        assert!(cte.column_list.is_none());

        if let CTEType::Standard(select) = &cte.cte_type {
            assert_eq!(select.from_table.as_deref(), Some("employees"));
            assert!(select.where_clause.is_none());
            assert!(select.ctes.is_empty());
            assert_eq!(select.select_items.len(), 2);
            assert!(matches!(select.select_items[0], SelectItem::Star));
            match &select.select_items[1] {
                SelectItem::Expression { alias, .. } => assert_eq!(alias, "__expr_1"),
                _ => panic!("expected the lifted expression as the second select item"),
            }
        } else {
            panic!("expected a standard CTE");
        }
    }
}