1use crate::sql::parser::ast::{
2 CTEType, Condition, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
3};
4use std::collections::{HashMap, HashSet};
5
6pub struct CTEHoister {
25 hoisted_ctes: Vec<CTE>,
26 cte_counter: usize,
27 dependency_graph: HashMap<String, HashSet<String>>,
28}
29
30impl CTEHoister {
31 pub fn new() -> Self {
32 Self {
33 hoisted_ctes: Vec::new(),
34 cte_counter: 0,
35 dependency_graph: HashMap::new(),
36 }
37 }
38
39 pub fn hoist_ctes(mut statement: SelectStatement) -> SelectStatement {
41 let mut hoister = CTEHoister::new();
42
43 for cte in statement.ctes.drain(..) {
45 hoister.add_cte(cte);
46 }
47
48 let rewritten = hoister.hoist_from_statement(statement);
50
51 SelectStatement {
53 ctes: hoister.get_ordered_ctes(),
54 ..rewritten
55 }
56 }
57
58 fn hoist_from_statement(&mut self, mut statement: SelectStatement) -> SelectStatement {
60 if let Some(subquery) = statement.from_subquery.take() {
62 let rewritten_sub = self.hoist_from_statement(*subquery);
63
64 for cte in rewritten_sub.ctes.clone() {
66 self.add_cte(cte);
67 }
68
69 statement.from_subquery = Some(Box::new(SelectStatement {
71 ctes: Vec::new(),
72 ..rewritten_sub
73 }));
74 }
75
76 let local_ctes = statement.ctes.drain(..).collect::<Vec<_>>();
78 for mut cte in local_ctes {
79 if let CTEType::Standard(query) = cte.cte_type {
81 let hoisted_query = self.hoist_from_statement(query);
82 cte.cte_type = CTEType::Standard(hoisted_query);
83 }
84 self.add_cte(cte);
86 }
87
88 statement.select_items = statement
90 .select_items
91 .into_iter()
92 .map(|item| self.hoist_from_select_item(item))
93 .collect();
94
95 if let Some(where_clause) = &mut statement.where_clause {
97 self.hoist_from_where_clause(where_clause);
98 }
99
100 SelectStatement {
102 ctes: Vec::new(),
103 ..statement
104 }
105 }
106
107 fn hoist_from_select_item(&mut self, item: SelectItem) -> SelectItem {
109 match item {
110 SelectItem::Expression { expr, alias } => SelectItem::Expression {
111 expr: self.hoist_from_expression(expr),
112 alias,
113 },
114 other => other,
115 }
116 }
117
118 fn hoist_from_expression(&mut self, expr: SqlExpression) -> SqlExpression {
120 match expr {
121 SqlExpression::ScalarSubquery { query } => {
122 let rewritten = self.hoist_from_statement(*query);
123 SqlExpression::ScalarSubquery {
124 query: Box::new(rewritten),
125 }
126 }
127 SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
128 left: Box::new(self.hoist_from_expression(*left)),
129 op,
130 right: Box::new(self.hoist_from_expression(*right)),
131 },
132 SqlExpression::FunctionCall {
133 name,
134 args,
135 distinct,
136 } => SqlExpression::FunctionCall {
137 name,
138 args: args
139 .into_iter()
140 .map(|arg| self.hoist_from_expression(arg))
141 .collect(),
142 distinct,
143 },
144 SqlExpression::CaseExpression {
145 when_branches,
146 else_branch,
147 } => SqlExpression::CaseExpression {
148 when_branches: when_branches
149 .into_iter()
150 .map(|branch| crate::sql::parser::ast::WhenBranch {
151 condition: Box::new(self.hoist_from_expression(*branch.condition)),
152 result: Box::new(self.hoist_from_expression(*branch.result)),
153 })
154 .collect(),
155 else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
156 },
157 SqlExpression::InList { expr, values } => SqlExpression::InList {
158 expr: Box::new(self.hoist_from_expression(*expr)),
159 values: values
160 .into_iter()
161 .map(|e| self.hoist_from_expression(e))
162 .collect(),
163 },
164 SqlExpression::NotInList { expr, values } => SqlExpression::NotInList {
165 expr: Box::new(self.hoist_from_expression(*expr)),
166 values: values
167 .into_iter()
168 .map(|e| self.hoist_from_expression(e))
169 .collect(),
170 },
171 SqlExpression::InSubquery { expr, subquery } => {
172 let rewritten = self.hoist_from_statement(*subquery);
173 SqlExpression::InSubquery {
174 expr: Box::new(self.hoist_from_expression(*expr)),
175 subquery: Box::new(rewritten),
176 }
177 }
178 SqlExpression::NotInSubquery { expr, subquery } => {
179 let rewritten = self.hoist_from_statement(*subquery);
180 SqlExpression::NotInSubquery {
181 expr: Box::new(self.hoist_from_expression(*expr)),
182 subquery: Box::new(rewritten),
183 }
184 }
185 SqlExpression::Between { expr, lower, upper } => SqlExpression::Between {
186 expr: Box::new(self.hoist_from_expression(*expr)),
187 lower: Box::new(self.hoist_from_expression(*lower)),
188 upper: Box::new(self.hoist_from_expression(*upper)),
189 },
190 SqlExpression::Not { expr } => SqlExpression::Not {
191 expr: Box::new(self.hoist_from_expression(*expr)),
192 },
193 SqlExpression::SimpleCaseExpression {
195 expr,
196 when_branches,
197 else_branch,
198 } => SqlExpression::SimpleCaseExpression {
199 expr: Box::new(self.hoist_from_expression(*expr)),
200 when_branches: when_branches
201 .into_iter()
202 .map(|branch| crate::sql::parser::ast::SimpleWhenBranch {
203 value: Box::new(self.hoist_from_expression(*branch.value)),
204 result: Box::new(self.hoist_from_expression(*branch.result)),
205 })
206 .collect(),
207 else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
208 },
209 other => other,
211 }
212 }
213
214 fn hoist_from_where_clause(&mut self, where_clause: &mut WhereClause) {
216 for condition in &mut where_clause.conditions {
217 condition.expr = self.hoist_from_expression(condition.expr.clone());
218 }
219 }
220
221 fn hoist_from_condition(&mut self, condition: &mut Condition) {
223 condition.expr = self.hoist_from_expression(condition.expr.clone());
224 }
225
226 fn add_cte(&mut self, cte: CTE) {
228 self.analyze_cte_dependencies(&cte);
230 self.hoisted_ctes.push(cte);
231 }
232
233 fn analyze_cte_dependencies(&mut self, cte: &CTE) {
235 let mut deps = HashSet::new();
236 if let CTEType::Standard(query) = &cte.cte_type {
237 self.find_cte_references(query, &mut deps);
238 }
239 self.dependency_graph.insert(cte.name.clone(), deps);
240 }
241
242 fn find_cte_references(&self, statement: &SelectStatement, deps: &mut HashSet<String>) {
244 if let Some(table) = &statement.from_table {
246 for cte in &self.hoisted_ctes {
248 if cte.name == *table {
249 deps.insert(table.clone());
250 }
251 }
252 }
253
254 if let Some(subquery) = &statement.from_subquery {
256 self.find_cte_references(subquery, deps);
257 }
258
259 for join in &statement.joins {
261 if let crate::sql::parser::ast::TableSource::Table(table_name) = &join.table {
263 for cte in &self.hoisted_ctes {
264 if cte.name == *table_name {
265 deps.insert(table_name.clone());
266 }
267 }
268 }
269 }
270
271 for item in &statement.select_items {
273 if let SelectItem::Expression { expr, .. } = item {
274 self.find_cte_refs_in_expression(expr, deps);
275 }
276 }
277
278 if let Some(where_clause) = &statement.where_clause {
280 for condition in &where_clause.conditions {
281 self.find_cte_refs_in_expression(&condition.expr, deps);
282 }
283 }
284 }
285
286 fn find_cte_refs_in_expression(&self, expr: &SqlExpression, deps: &mut HashSet<String>) {
288 match expr {
289 SqlExpression::ScalarSubquery { query } => {
290 self.find_cte_references(query, deps);
291 }
292 SqlExpression::InSubquery { subquery, .. } => {
293 self.find_cte_references(subquery, deps);
294 }
295 SqlExpression::NotInSubquery { subquery, .. } => {
296 self.find_cte_references(subquery, deps);
297 }
298 SqlExpression::FunctionCall { args, .. } => {
299 for arg in args {
300 self.find_cte_refs_in_expression(arg, deps);
301 }
302 }
303 SqlExpression::BinaryOp { left, right, .. } => {
304 self.find_cte_refs_in_expression(left, deps);
305 self.find_cte_refs_in_expression(right, deps);
306 }
307 SqlExpression::CaseExpression {
308 when_branches,
309 else_branch,
310 } => {
311 for branch in when_branches {
312 self.find_cte_refs_in_expression(&branch.condition, deps);
313 self.find_cte_refs_in_expression(&branch.result, deps);
314 }
315 if let Some(else_expr) = else_branch {
316 self.find_cte_refs_in_expression(else_expr, deps);
317 }
318 }
319 SqlExpression::SimpleCaseExpression {
320 expr,
321 when_branches,
322 else_branch,
323 } => {
324 self.find_cte_refs_in_expression(expr, deps);
325 for branch in when_branches {
326 self.find_cte_refs_in_expression(&branch.value, deps);
327 self.find_cte_refs_in_expression(&branch.result, deps);
328 }
329 if let Some(else_expr) = else_branch {
330 self.find_cte_refs_in_expression(else_expr, deps);
331 }
332 }
333 SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
334 self.find_cte_refs_in_expression(expr, deps);
335 for value in values {
336 self.find_cte_refs_in_expression(value, deps);
337 }
338 }
339 SqlExpression::Between { expr, lower, upper } => {
340 self.find_cte_refs_in_expression(expr, deps);
341 self.find_cte_refs_in_expression(lower, deps);
342 self.find_cte_refs_in_expression(upper, deps);
343 }
344 SqlExpression::Not { expr } => {
345 self.find_cte_refs_in_expression(expr, deps);
346 }
347 _ => {}
348 }
349 }
350
351 fn get_ordered_ctes(self) -> Vec<CTE> {
353 let mut result = Vec::new();
355 let mut visited = HashSet::new();
356 let mut temp_mark = HashSet::new();
357
358 fn visit(
359 name: &str,
360 graph: &HashMap<String, HashSet<String>>,
361 ctes: &[CTE],
362 visited: &mut HashSet<String>,
363 temp_mark: &mut HashSet<String>,
364 result: &mut Vec<CTE>,
365 ) {
366 if visited.contains(name) {
367 return;
368 }
369 if temp_mark.contains(name) {
370 return;
372 }
373
374 temp_mark.insert(name.to_string());
375
376 if let Some(deps) = graph.get(name) {
377 for dep in deps {
378 visit(dep, graph, ctes, visited, temp_mark, result);
379 }
380 }
381
382 temp_mark.remove(name);
383 visited.insert(name.to_string());
384
385 if let Some(cte) = ctes.iter().find(|c| c.name == name) {
387 result.push(cte.clone());
388 }
389 }
390
391 for cte in &self.hoisted_ctes {
393 visit(
394 &cte.name,
395 &self.dependency_graph,
396 &self.hoisted_ctes,
397 &mut visited,
398 &mut temp_mark,
399 &mut result,
400 );
401 }
402
403 result
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// A statement with every clause absent or empty; tests override only the
    /// fields they care about.
    fn blank_select() -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: vec![],
            select_items: vec![],
            from_table: None,
            from_subquery: None,
            from_function: None,
            from_alias: None,
            joins: vec![],
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            limit: None,
            offset: None,
            ctes: vec![],
        }
    }

    /// A CTE declared inside a FROM subquery must end up on the outer
    /// statement, leaving the subquery without CTEs of its own.
    #[test]
    fn test_simple_cte_hoisting() {
        // Body of the CTE: reads `col1` from `table1`.
        let inner_query = SelectStatement {
            columns: vec!["col1".to_string()],
            from_table: Some("table1".to_string()),
            ..blank_select()
        };

        // Subquery that declares the CTE `inner` and selects from it.
        let subquery = SelectStatement {
            ctes: vec![CTE {
                name: "inner".to_string(),
                column_list: None,
                cte_type: CTEType::Standard(inner_query),
            }],
            from_table: Some("inner".to_string()),
            ..blank_select()
        };

        // Outer statement whose only source is that subquery.
        let nested_query = SelectStatement {
            from_subquery: Some(Box::new(subquery)),
            ..blank_select()
        };

        let result = CTEHoister::hoist_ctes(nested_query);

        assert_eq!(result.ctes.len(), 1);
        assert_eq!(result.ctes[0].name, "inner");
        assert!(result.from_subquery.as_ref().unwrap().ctes.is_empty());
    }
}