1use crate::sql::parser::ast::{
2    CTEType, ColumnRef, Condition, SelectItem, SelectStatement, SqlExpression, CTE,
3};
4
/// Rewrites `IN` / `NOT IN` predicates whose left-hand operand is not a plain
/// column reference by lifting that operand into an aliased, computed column.
///
/// The lifter is stateful only so that every alias it generates is unique for
/// the lifetime of one instance. `Default` yields the same state as `new()`.
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of synthetic aliases handed out so far; the next alias is
    /// `__expr_{column_counter + 1}`.
    column_counter: usize,
}
10
11impl InOperatorLifter {
12    pub fn new() -> Self {
13        InOperatorLifter { column_counter: 0 }
14    }
15
16    fn next_column_name(&mut self) -> String {
18        self.column_counter += 1;
19        format!("__expr_{}", self.column_counter)
20    }
21
22    pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24        match expr {
25            SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26                Self::is_complex_expression(expr)
27            }
28            _ => false,
29        }
30    }
31
32    fn is_complex_expression(expr: &SqlExpression) -> bool {
34        !matches!(expr, SqlExpression::Column(_))
35    }
36
37    pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39        let mut lifted = Vec::new();
40
41        if let Some(ref mut where_clause) = stmt.where_clause {
42            let mut new_conditions = Vec::new();
43
44            for condition in &where_clause.conditions {
45                match &condition.expr {
46                    SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47                        let column_alias = self.next_column_name();
48
49                        lifted.push(LiftedInExpression {
51                            original_expr: expr.as_ref().clone(),
52                            alias: column_alias.clone(),
53                            values: values.clone(),
54                            is_not_in: false,
55                        });
56
57                        new_conditions.push(Condition {
59                            expr: SqlExpression::InList {
60                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
61                                    column_alias,
62                                ))),
63                                values: values.clone(),
64                            },
65                            connector: condition.connector.clone(),
66                        });
67                    }
68                    SqlExpression::NotInList { expr, values }
69                        if Self::is_complex_expression(expr) =>
70                    {
71                        let column_alias = self.next_column_name();
72
73                        lifted.push(LiftedInExpression {
75                            original_expr: expr.as_ref().clone(),
76                            alias: column_alias.clone(),
77                            values: values.clone(),
78                            is_not_in: true,
79                        });
80
81                        new_conditions.push(Condition {
83                            expr: SqlExpression::NotInList {
84                                expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
85                                    column_alias,
86                                ))),
87                                values: values.clone(),
88                            },
89                            connector: condition.connector.clone(),
90                        });
91                    }
92                    _ => {
93                        new_conditions.push(condition.clone());
95                    }
96                }
97            }
98
99            where_clause.conditions = new_conditions;
101        }
102
103        lifted
104    }
105
106    pub fn apply_lifted_to_select(
108        &self,
109        stmt: &mut SelectStatement,
110        lifted: &[LiftedInExpression],
111    ) {
112        for lift in lifted {
114            stmt.select_items.push(SelectItem::Expression {
115                expr: lift.original_expr.clone(),
116                alias: lift.alias.clone(),
117                leading_comments: vec![],
118                trailing_comment: None,
119            });
120        }
121    }
122
123    pub fn create_lifting_cte(
125        &self,
126        base_table: &str,
127        lifted: &[LiftedInExpression],
128        cte_name: String,
129    ) -> CTE {
130        let mut select_items = vec![SelectItem::Star {
131            table_prefix: None,
132            leading_comments: vec![],
133            trailing_comment: None,
134        }];
135
136        for lift in lifted {
138            select_items.push(SelectItem::Expression {
139                expr: lift.original_expr.clone(),
140                alias: lift.alias.clone(),
141                leading_comments: vec![],
142                trailing_comment: None,
143            });
144        }
145
146        let cte_select = SelectStatement {
147            distinct: false,
148            columns: vec!["*".to_string()],
149            select_items,
150            from_table: Some(base_table.to_string()),
151            from_subquery: None,
152            from_function: None,
153            from_alias: None,
154            joins: Vec::new(),
155            where_clause: None,
156            order_by: None,
157            group_by: None,
158            having: None,
159            limit: None,
160            offset: None,
161            ctes: Vec::new(),
162            into_table: None,
163            set_operations: Vec::new(),
164            leading_comments: vec![],
165            trailing_comment: None,
166        };
167
168        CTE {
169            name: cte_name,
170            column_list: None,
171            cte_type: CTEType::Standard(cte_select),
172        }
173    }
174
175    pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
177        let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
179            where_clause
180                .conditions
181                .iter()
182                .any(|c| Self::needs_in_lifting(&c.expr))
183        } else {
184            false
185        };
186
187        if !has_in_to_lift {
188            return false;
189        }
190
191        let base_table = match &stmt.from_table {
193            Some(table) => table.clone(),
194            None => return false, };
196
197        let lifted = self.lift_in_expressions(stmt);
199
200        if lifted.is_empty() {
201            return false;
202        }
203
204        let cte_name = format!("{}_lifted", base_table);
206        let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
207
208        stmt.ctes.push(cte);
210
211        stmt.from_table = Some(cte_name);
213
214        true
215    }
216}
217
218#[derive(Debug, Clone)]
220pub struct LiftedInExpression {
221    pub original_expr: SqlExpression,
223    pub alias: String,
225    pub values: Vec<SqlExpression>,
227    pub is_not_in: bool,
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare, unquoted column reference.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// `LOWER(<column_name>)` function call.
    fn lower_of(column_name: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![column(column_name)],
            distinct: false,
        }
    }

    /// `<operand> IN ('a')`.
    fn in_list_over(operand: SqlExpression) -> SqlExpression {
        SqlExpression::InList {
            expr: Box::new(operand),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        }
    }

    #[test]
    fn test_needs_in_lifting() {
        // A plain column on the left-hand side needs no lifting.
        let simple_in = in_list_over(column("col"));
        assert!(!InOperatorLifter::needs_in_lifting(&simple_in));

        // A function call on the left-hand side does.
        let func_in = in_list_over(lower_of("col"));
        assert!(InOperatorLifter::needs_in_lifting(&func_in));
    }

    #[test]
    fn test_is_complex_expression() {
        let plain_column = column("col");
        assert!(!InOperatorLifter::is_complex_expression(&plain_column));

        let function_call = lower_of("col");
        assert!(InOperatorLifter::is_complex_expression(&function_call));

        let addition = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "+".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
        };
        assert!(InOperatorLifter::is_complex_expression(&addition));
    }
}