1use crate::sql::parser::ast::{
2 CTEType, ColumnRef, Condition, SelectItem, SelectStatement, SqlExpression, TableSource, CTE,
3};
4
/// Rewrites `IN` / `NOT IN` predicates whose tested operand is not a plain
/// column, by giving the operand a generated alias (`__expr_N`) that can be
/// projected as a column and tested instead.
///
/// `Default` yields the same state as [`InOperatorLifter::new`]: a counter of
/// zero, so the first generated alias is `__expr_1`.
#[derive(Debug, Default)]
pub struct InOperatorLifter {
    /// Number of aliases handed out so far; incremented before each alias is
    /// formatted.
    column_counter: usize,
}
10
11impl InOperatorLifter {
12 pub fn new() -> Self {
13 InOperatorLifter { column_counter: 0 }
14 }
15
16 fn next_column_name(&mut self) -> String {
18 self.column_counter += 1;
19 format!("__expr_{}", self.column_counter)
20 }
21
22 pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24 match expr {
25 SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26 Self::is_complex_expression(expr)
27 }
28 _ => false,
29 }
30 }
31
32 fn is_complex_expression(expr: &SqlExpression) -> bool {
34 !matches!(expr, SqlExpression::Column(_))
35 }
36
37 pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39 let mut lifted = Vec::new();
40
41 if let Some(ref mut where_clause) = stmt.where_clause {
42 let mut new_conditions = Vec::new();
43
44 for condition in &where_clause.conditions {
45 match &condition.expr {
46 SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47 let column_alias = self.next_column_name();
48
49 lifted.push(LiftedInExpression {
51 original_expr: expr.as_ref().clone(),
52 alias: column_alias.clone(),
53 values: values.clone(),
54 is_not_in: false,
55 });
56
57 new_conditions.push(Condition {
59 expr: SqlExpression::InList {
60 expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
61 column_alias,
62 ))),
63 values: values.clone(),
64 },
65 connector: condition.connector.clone(),
66 });
67 }
68 SqlExpression::NotInList { expr, values }
69 if Self::is_complex_expression(expr) =>
70 {
71 let column_alias = self.next_column_name();
72
73 lifted.push(LiftedInExpression {
75 original_expr: expr.as_ref().clone(),
76 alias: column_alias.clone(),
77 values: values.clone(),
78 is_not_in: true,
79 });
80
81 new_conditions.push(Condition {
83 expr: SqlExpression::NotInList {
84 expr: Box::new(SqlExpression::Column(ColumnRef::unquoted(
85 column_alias,
86 ))),
87 values: values.clone(),
88 },
89 connector: condition.connector.clone(),
90 });
91 }
92 _ => {
93 new_conditions.push(condition.clone());
95 }
96 }
97 }
98
99 where_clause.conditions = new_conditions;
101 }
102
103 lifted
104 }
105
106 pub fn apply_lifted_to_select(
108 &self,
109 stmt: &mut SelectStatement,
110 lifted: &[LiftedInExpression],
111 ) {
112 for lift in lifted {
114 stmt.select_items.push(SelectItem::Expression {
115 expr: lift.original_expr.clone(),
116 alias: lift.alias.clone(),
117 leading_comments: vec![],
118 trailing_comment: None,
119 });
120 }
121 }
122
123 pub fn create_lifting_cte(
125 &self,
126 base_table: &str,
127 lifted: &[LiftedInExpression],
128 cte_name: String,
129 ) -> CTE {
130 let mut select_items = vec![SelectItem::Star {
131 table_prefix: None,
132 leading_comments: vec![],
133 trailing_comment: None,
134 }];
135
136 for lift in lifted {
138 select_items.push(SelectItem::Expression {
139 expr: lift.original_expr.clone(),
140 alias: lift.alias.clone(),
141 leading_comments: vec![],
142 trailing_comment: None,
143 });
144 }
145
146 let cte_select = SelectStatement {
147 distinct: false,
148 columns: vec!["*".to_string()],
149 select_items,
150 from_source: Some(TableSource::Table(base_table.to_string())),
151 #[allow(deprecated)]
152 from_table: Some(base_table.to_string()),
153 #[allow(deprecated)]
154 from_subquery: None,
155 #[allow(deprecated)]
156 from_function: None,
157 #[allow(deprecated)]
158 from_alias: None,
159 joins: Vec::new(),
160 where_clause: None,
161 order_by: None,
162 group_by: None,
163 having: None,
164 limit: None,
165 offset: None,
166 ctes: Vec::new(),
167 into_table: None,
168 set_operations: Vec::new(),
169 leading_comments: vec![],
170 trailing_comment: None,
171 qualify: None,
172 };
173
174 CTE {
175 name: cte_name,
176 column_list: None,
177 cte_type: CTEType::Standard(cte_select),
178 }
179 }
180
181 pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
183 let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
185 where_clause
186 .conditions
187 .iter()
188 .any(|c| Self::needs_in_lifting(&c.expr))
189 } else {
190 false
191 };
192
193 if !has_in_to_lift {
194 return false;
195 }
196
197 let base_table = match &stmt.from_table {
199 Some(table) => table.clone(),
200 None => return false, };
202
203 let lifted = self.lift_in_expressions(stmt);
205
206 if lifted.is_empty() {
207 return false;
208 }
209
210 let cte_name = format!("{}_lifted", base_table);
212 let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
213
214 stmt.ctes.push(cte);
216
217 stmt.from_table = Some(cte_name);
219
220 true
221 }
222}
223
/// Record of one `IN` / `NOT IN` predicate whose operand was replaced by a
/// generated column alias.
#[derive(Debug, Clone)]
pub struct LiftedInExpression {
    /// The operand expression that originally appeared on the left of the
    /// `IN` / `NOT IN` list.
    pub original_expr: SqlExpression,
    /// Generated alias (`__expr_N`) that now stands in for the operand.
    pub alias: String,
    /// Copy of the list values the operand is tested against.
    pub values: Vec<SqlExpression>,
    /// `true` when the predicate was `NOT IN`, `false` for `IN`.
    pub is_not_in: bool,
}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unquoted column reference expression.
    fn column(name: &str) -> SqlExpression {
        SqlExpression::Column(ColumnRef::unquoted(name.to_string()))
    }

    /// Builds `LOWER(<name>)` as a non-distinct function call.
    fn lower_of(name: &str) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![column(name)],
            distinct: false,
        }
    }

    /// Builds `<operand> IN ('a')`.
    fn in_list_over(operand: SqlExpression) -> SqlExpression {
        SqlExpression::InList {
            expr: Box::new(operand),
            values: vec![SqlExpression::StringLiteral("a".to_string())],
        }
    }

    #[test]
    fn test_needs_in_lifting() {
        // A bare column on the left of IN needs no lifting.
        let plain = in_list_over(column("col"));
        assert!(!InOperatorLifter::needs_in_lifting(&plain));

        // A function call on the left of IN must be lifted.
        let wrapped = in_list_over(lower_of("col"));
        assert!(InOperatorLifter::needs_in_lifting(&wrapped));
    }

    #[test]
    fn test_is_complex_expression() {
        let plain_column = column("col");
        assert!(!InOperatorLifter::is_complex_expression(&plain_column));

        let function_call = lower_of("col");
        assert!(InOperatorLifter::is_complex_expression(&function_call));

        let arithmetic = SqlExpression::BinaryOp {
            left: Box::new(column("a")),
            op: "+".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
        };
        assert!(InOperatorLifter::is_complex_expression(&arithmetic));
    }
}