sql_cli/query_plan/
in_operator_lifter.rs1use crate::sql::parser::ast::{
2 CTEType, Condition, SelectItem, SelectStatement, SqlExpression, CTE,
3};
4
5pub struct InOperatorLifter {
7 column_counter: usize,
9}
10
11impl InOperatorLifter {
12 pub fn new() -> Self {
13 InOperatorLifter { column_counter: 0 }
14 }
15
16 fn next_column_name(&mut self) -> String {
18 self.column_counter += 1;
19 format!("__expr_{}", self.column_counter)
20 }
21
22 pub fn needs_in_lifting(expr: &SqlExpression) -> bool {
24 match expr {
25 SqlExpression::InList { expr, .. } | SqlExpression::NotInList { expr, .. } => {
26 Self::is_complex_expression(expr)
27 }
28 _ => false,
29 }
30 }
31
32 fn is_complex_expression(expr: &SqlExpression) -> bool {
34 !matches!(expr, SqlExpression::Column(_))
35 }
36
37 pub fn lift_in_expressions(&mut self, stmt: &mut SelectStatement) -> Vec<LiftedInExpression> {
39 let mut lifted = Vec::new();
40
41 if let Some(ref mut where_clause) = stmt.where_clause {
42 let mut new_conditions = Vec::new();
43
44 for condition in &where_clause.conditions {
45 match &condition.expr {
46 SqlExpression::InList { expr, values } if Self::is_complex_expression(expr) => {
47 let column_alias = self.next_column_name();
48
49 lifted.push(LiftedInExpression {
51 original_expr: expr.as_ref().clone(),
52 alias: column_alias.clone(),
53 values: values.clone(),
54 is_not_in: false,
55 });
56
57 new_conditions.push(Condition {
59 expr: SqlExpression::InList {
60 expr: Box::new(SqlExpression::Column(column_alias)),
61 values: values.clone(),
62 },
63 connector: condition.connector.clone(),
64 });
65 }
66 SqlExpression::NotInList { expr, values }
67 if Self::is_complex_expression(expr) =>
68 {
69 let column_alias = self.next_column_name();
70
71 lifted.push(LiftedInExpression {
73 original_expr: expr.as_ref().clone(),
74 alias: column_alias.clone(),
75 values: values.clone(),
76 is_not_in: true,
77 });
78
79 new_conditions.push(Condition {
81 expr: SqlExpression::NotInList {
82 expr: Box::new(SqlExpression::Column(column_alias)),
83 values: values.clone(),
84 },
85 connector: condition.connector.clone(),
86 });
87 }
88 _ => {
89 new_conditions.push(condition.clone());
91 }
92 }
93 }
94
95 where_clause.conditions = new_conditions;
97 }
98
99 lifted
100 }
101
102 pub fn apply_lifted_to_select(
104 &self,
105 stmt: &mut SelectStatement,
106 lifted: &[LiftedInExpression],
107 ) {
108 for lift in lifted {
110 stmt.select_items.push(SelectItem::Expression {
111 expr: lift.original_expr.clone(),
112 alias: lift.alias.clone(),
113 });
114 }
115 }
116
117 pub fn create_lifting_cte(
119 &self,
120 base_table: &str,
121 lifted: &[LiftedInExpression],
122 cte_name: String,
123 ) -> CTE {
124 let mut select_items = vec![SelectItem::Star];
125
126 for lift in lifted {
128 select_items.push(SelectItem::Expression {
129 expr: lift.original_expr.clone(),
130 alias: lift.alias.clone(),
131 });
132 }
133
134 let cte_select = SelectStatement {
135 distinct: false,
136 columns: vec!["*".to_string()],
137 select_items,
138 from_table: Some(base_table.to_string()),
139 from_subquery: None,
140 from_function: None,
141 from_alias: None,
142 joins: Vec::new(),
143 where_clause: None,
144 order_by: None,
145 group_by: None,
146 having: None,
147 limit: None,
148 offset: None,
149 ctes: Vec::new(),
150 };
151
152 CTE {
153 name: cte_name,
154 column_list: None,
155 cte_type: CTEType::Standard(cte_select),
156 }
157 }
158
159 pub fn rewrite_query(&mut self, stmt: &mut SelectStatement) -> bool {
161 let has_in_to_lift = if let Some(ref where_clause) = stmt.where_clause {
163 where_clause
164 .conditions
165 .iter()
166 .any(|c| Self::needs_in_lifting(&c.expr))
167 } else {
168 false
169 };
170
171 if !has_in_to_lift {
172 return false;
173 }
174
175 let base_table = match &stmt.from_table {
177 Some(table) => table.clone(),
178 None => return false, };
180
181 let lifted = self.lift_in_expressions(stmt);
183
184 if lifted.is_empty() {
185 return false;
186 }
187
188 let cte_name = format!("{}_lifted", base_table);
190 let cte = self.create_lifting_cte(&base_table, &lifted, cte_name.clone());
191
192 stmt.ctes.push(cte);
194
195 stmt.from_table = Some(cte_name);
197
198 true
199 }
200}
201
202#[derive(Debug, Clone)]
204pub struct LiftedInExpression {
205 pub original_expr: SqlExpression,
207 pub alias: String,
209 pub values: Vec<SqlExpression>,
211 pub is_not_in: bool,
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Bare column reference fixture.
    fn col(name: &str) -> SqlExpression {
        SqlExpression::Column(name.to_string())
    }

    /// `LOWER(<arg>)` fixture — a typical "complex" probe expression.
    fn lower(arg: SqlExpression) -> SqlExpression {
        SqlExpression::FunctionCall {
            name: "LOWER".to_string(),
            args: vec![arg],
            distinct: false,
        }
    }

    /// Single-element value list `('a')`.
    fn list_a() -> Vec<SqlExpression> {
        vec![SqlExpression::StringLiteral("a".to_string())]
    }

    #[test]
    fn test_needs_in_lifting() {
        let simple_in = SqlExpression::InList {
            expr: Box::new(col("col")),
            values: list_a(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_in));

        let func_in = SqlExpression::InList {
            expr: Box::new(lower(col("col"))),
            values: list_a(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_in));
    }

    #[test]
    fn test_needs_in_lifting_not_in() {
        let simple_not_in = SqlExpression::NotInList {
            expr: Box::new(col("col")),
            values: list_a(),
        };
        assert!(!InOperatorLifter::needs_in_lifting(&simple_not_in));

        let func_not_in = SqlExpression::NotInList {
            expr: Box::new(lower(col("col"))),
            values: list_a(),
        };
        assert!(InOperatorLifter::needs_in_lifting(&func_not_in));
    }

    #[test]
    fn test_needs_in_lifting_ignores_non_in_expressions() {
        // Only IN / NOT IN predicates are candidates, however complex.
        assert!(!InOperatorLifter::needs_in_lifting(&col("col")));
        assert!(!InOperatorLifter::needs_in_lifting(&lower(col("col"))));
    }

    #[test]
    fn test_is_complex_expression() {
        assert!(!InOperatorLifter::is_complex_expression(&col("col")));

        assert!(InOperatorLifter::is_complex_expression(&lower(col("col"))));

        let addition = SqlExpression::BinaryOp {
            left: Box::new(col("a")),
            op: "+".to_string(),
            right: Box::new(SqlExpression::NumberLiteral("1".to_string())),
        };
        assert!(InOperatorLifter::is_complex_expression(&addition));
    }

    #[test]
    fn test_next_column_name_is_sequential() {
        let mut lifter = InOperatorLifter::new();
        assert_eq!(lifter.next_column_name(), "__expr_1");
        assert_eq!(lifter.next_column_name(), "__expr_2");
    }
}