1use crate::sql::parser::ast::{
2 CTEType, Condition, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
3};
4use std::collections::{HashMap, HashSet};
5
6pub struct CTEHoister {
25 hoisted_ctes: Vec<CTE>,
26 _cte_counter: usize,
27 dependency_graph: HashMap<String, HashSet<String>>,
28}
29
30impl CTEHoister {
31 pub fn new() -> Self {
32 Self {
33 hoisted_ctes: Vec::new(),
34 _cte_counter: 0,
35 dependency_graph: HashMap::new(),
36 }
37 }
38
39 pub fn hoist_ctes(mut statement: SelectStatement) -> SelectStatement {
41 let mut hoister = CTEHoister::new();
42
43 for cte in statement.ctes.drain(..) {
45 hoister.add_cte(cte);
46 }
47
48 let rewritten = hoister.hoist_from_statement(statement);
50
51 SelectStatement {
53 ctes: hoister.get_ordered_ctes(),
54 ..rewritten
55 }
56 }
57
58 fn hoist_from_statement(&mut self, mut statement: SelectStatement) -> SelectStatement {
60 if let Some(subquery) = statement.from_subquery.take() {
62 let rewritten_sub = self.hoist_from_statement(*subquery);
63
64 for cte in rewritten_sub.ctes.clone() {
66 self.add_cte(cte);
67 }
68
69 statement.from_subquery = Some(Box::new(SelectStatement {
71 ctes: Vec::new(),
72 ..rewritten_sub
73 }));
74 }
75
76 let local_ctes = statement.ctes.drain(..).collect::<Vec<_>>();
78 for mut cte in local_ctes {
79 if let CTEType::Standard(query) = cte.cte_type {
81 let hoisted_query = self.hoist_from_statement(query);
82 cte.cte_type = CTEType::Standard(hoisted_query);
83 }
84 self.add_cte(cte);
86 }
87
88 statement.select_items = statement
90 .select_items
91 .into_iter()
92 .map(|item| self.hoist_from_select_item(item))
93 .collect();
94
95 if let Some(where_clause) = &mut statement.where_clause {
97 self.hoist_from_where_clause(where_clause);
98 }
99
100 SelectStatement {
102 ctes: Vec::new(),
103 ..statement
104 }
105 }
106
107 fn hoist_from_select_item(&mut self, item: SelectItem) -> SelectItem {
109 match item {
110 SelectItem::Expression {
111 expr,
112 alias,
113 leading_comments,
114 trailing_comment,
115 } => SelectItem::Expression {
116 expr: self.hoist_from_expression(expr),
117 alias,
118 leading_comments,
119 trailing_comment,
120 },
121 other => other,
122 }
123 }
124
125 fn hoist_from_expression(&mut self, expr: SqlExpression) -> SqlExpression {
127 match expr {
128 SqlExpression::ScalarSubquery { query } => {
129 let rewritten = self.hoist_from_statement(*query);
130 SqlExpression::ScalarSubquery {
131 query: Box::new(rewritten),
132 }
133 }
134 SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
135 left: Box::new(self.hoist_from_expression(*left)),
136 op,
137 right: Box::new(self.hoist_from_expression(*right)),
138 },
139 SqlExpression::FunctionCall {
140 name,
141 args,
142 distinct,
143 } => SqlExpression::FunctionCall {
144 name,
145 args: args
146 .into_iter()
147 .map(|arg| self.hoist_from_expression(arg))
148 .collect(),
149 distinct,
150 },
151 SqlExpression::CaseExpression {
152 when_branches,
153 else_branch,
154 } => SqlExpression::CaseExpression {
155 when_branches: when_branches
156 .into_iter()
157 .map(|branch| crate::sql::parser::ast::WhenBranch {
158 condition: Box::new(self.hoist_from_expression(*branch.condition)),
159 result: Box::new(self.hoist_from_expression(*branch.result)),
160 })
161 .collect(),
162 else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
163 },
164 SqlExpression::InList { expr, values } => SqlExpression::InList {
165 expr: Box::new(self.hoist_from_expression(*expr)),
166 values: values
167 .into_iter()
168 .map(|e| self.hoist_from_expression(e))
169 .collect(),
170 },
171 SqlExpression::NotInList { expr, values } => SqlExpression::NotInList {
172 expr: Box::new(self.hoist_from_expression(*expr)),
173 values: values
174 .into_iter()
175 .map(|e| self.hoist_from_expression(e))
176 .collect(),
177 },
178 SqlExpression::InSubquery { expr, subquery } => {
179 let rewritten = self.hoist_from_statement(*subquery);
180 SqlExpression::InSubquery {
181 expr: Box::new(self.hoist_from_expression(*expr)),
182 subquery: Box::new(rewritten),
183 }
184 }
185 SqlExpression::NotInSubquery { expr, subquery } => {
186 let rewritten = self.hoist_from_statement(*subquery);
187 SqlExpression::NotInSubquery {
188 expr: Box::new(self.hoist_from_expression(*expr)),
189 subquery: Box::new(rewritten),
190 }
191 }
192 SqlExpression::Between { expr, lower, upper } => SqlExpression::Between {
193 expr: Box::new(self.hoist_from_expression(*expr)),
194 lower: Box::new(self.hoist_from_expression(*lower)),
195 upper: Box::new(self.hoist_from_expression(*upper)),
196 },
197 SqlExpression::Not { expr } => SqlExpression::Not {
198 expr: Box::new(self.hoist_from_expression(*expr)),
199 },
200 SqlExpression::SimpleCaseExpression {
202 expr,
203 when_branches,
204 else_branch,
205 } => SqlExpression::SimpleCaseExpression {
206 expr: Box::new(self.hoist_from_expression(*expr)),
207 when_branches: when_branches
208 .into_iter()
209 .map(|branch| crate::sql::parser::ast::SimpleWhenBranch {
210 value: Box::new(self.hoist_from_expression(*branch.value)),
211 result: Box::new(self.hoist_from_expression(*branch.result)),
212 })
213 .collect(),
214 else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
215 },
216 other => other,
218 }
219 }
220
221 fn hoist_from_where_clause(&mut self, where_clause: &mut WhereClause) {
223 for condition in &mut where_clause.conditions {
224 condition.expr = self.hoist_from_expression(condition.expr.clone());
225 }
226 }
227
228 fn hoist_from_condition(&mut self, condition: &mut Condition) {
230 condition.expr = self.hoist_from_expression(condition.expr.clone());
231 }
232
233 fn add_cte(&mut self, cte: CTE) {
235 self.analyze_cte_dependencies(&cte);
237 self.hoisted_ctes.push(cte);
238 }
239
240 fn analyze_cte_dependencies(&mut self, cte: &CTE) {
242 let mut deps = HashSet::new();
243 if let CTEType::Standard(query) = &cte.cte_type {
244 self.find_cte_references(query, &mut deps);
245 }
246 self.dependency_graph.insert(cte.name.clone(), deps);
247 }
248
249 fn find_cte_references(&self, statement: &SelectStatement, deps: &mut HashSet<String>) {
251 if let Some(table) = &statement.from_table {
253 for cte in &self.hoisted_ctes {
255 if cte.name == *table {
256 deps.insert(table.clone());
257 }
258 }
259 }
260
261 if let Some(subquery) = &statement.from_subquery {
263 self.find_cte_references(subquery, deps);
264 }
265
266 for join in &statement.joins {
268 if let crate::sql::parser::ast::TableSource::Table(table_name) = &join.table {
270 for cte in &self.hoisted_ctes {
271 if cte.name == *table_name {
272 deps.insert(table_name.clone());
273 }
274 }
275 }
276 }
277
278 for item in &statement.select_items {
280 if let SelectItem::Expression { expr, .. } = item {
281 self.find_cte_refs_in_expression(expr, deps);
282 }
283 }
284
285 if let Some(where_clause) = &statement.where_clause {
287 for condition in &where_clause.conditions {
288 self.find_cte_refs_in_expression(&condition.expr, deps);
289 }
290 }
291 }
292
293 fn find_cte_refs_in_expression(&self, expr: &SqlExpression, deps: &mut HashSet<String>) {
295 match expr {
296 SqlExpression::ScalarSubquery { query } => {
297 self.find_cte_references(query, deps);
298 }
299 SqlExpression::InSubquery { subquery, .. } => {
300 self.find_cte_references(subquery, deps);
301 }
302 SqlExpression::NotInSubquery { subquery, .. } => {
303 self.find_cte_references(subquery, deps);
304 }
305 SqlExpression::FunctionCall { args, .. } => {
306 for arg in args {
307 self.find_cte_refs_in_expression(arg, deps);
308 }
309 }
310 SqlExpression::BinaryOp { left, right, .. } => {
311 self.find_cte_refs_in_expression(left, deps);
312 self.find_cte_refs_in_expression(right, deps);
313 }
314 SqlExpression::CaseExpression {
315 when_branches,
316 else_branch,
317 } => {
318 for branch in when_branches {
319 self.find_cte_refs_in_expression(&branch.condition, deps);
320 self.find_cte_refs_in_expression(&branch.result, deps);
321 }
322 if let Some(else_expr) = else_branch {
323 self.find_cte_refs_in_expression(else_expr, deps);
324 }
325 }
326 SqlExpression::SimpleCaseExpression {
327 expr,
328 when_branches,
329 else_branch,
330 } => {
331 self.find_cte_refs_in_expression(expr, deps);
332 for branch in when_branches {
333 self.find_cte_refs_in_expression(&branch.value, deps);
334 self.find_cte_refs_in_expression(&branch.result, deps);
335 }
336 if let Some(else_expr) = else_branch {
337 self.find_cte_refs_in_expression(else_expr, deps);
338 }
339 }
340 SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
341 self.find_cte_refs_in_expression(expr, deps);
342 for value in values {
343 self.find_cte_refs_in_expression(value, deps);
344 }
345 }
346 SqlExpression::Between { expr, lower, upper } => {
347 self.find_cte_refs_in_expression(expr, deps);
348 self.find_cte_refs_in_expression(lower, deps);
349 self.find_cte_refs_in_expression(upper, deps);
350 }
351 SqlExpression::Not { expr } => {
352 self.find_cte_refs_in_expression(expr, deps);
353 }
354 _ => {}
355 }
356 }
357
358 fn get_ordered_ctes(self) -> Vec<CTE> {
360 let mut result = Vec::new();
362 let mut visited = HashSet::new();
363 let mut temp_mark = HashSet::new();
364
365 fn visit(
366 name: &str,
367 graph: &HashMap<String, HashSet<String>>,
368 ctes: &[CTE],
369 visited: &mut HashSet<String>,
370 temp_mark: &mut HashSet<String>,
371 result: &mut Vec<CTE>,
372 ) {
373 if visited.contains(name) {
374 return;
375 }
376 if temp_mark.contains(name) {
377 return;
379 }
380
381 temp_mark.insert(name.to_string());
382
383 if let Some(deps) = graph.get(name) {
384 for dep in deps {
385 visit(dep, graph, ctes, visited, temp_mark, result);
386 }
387 }
388
389 temp_mark.remove(name);
390 visited.insert(name.to_string());
391
392 if let Some(cte) = ctes.iter().find(|c| c.name == name) {
394 result.push(cte.clone());
395 }
396 }
397
398 for cte in &self.hoisted_ctes {
400 visit(
401 &cte.name,
402 &self.dependency_graph,
403 &self.hoisted_ctes,
404 &mut visited,
405 &mut temp_mark,
406 &mut result,
407 );
408 }
409
410 result
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SelectStatement` with every field empty or unset, so each test only has
    /// to spell out the fields it cares about.
    fn blank_select() -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: vec![],
            select_items: vec![],
            from_table: None,
            from_subquery: None,
            from_function: None,
            from_alias: None,
            joins: vec![],
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            limit: None,
            offset: None,
            ctes: vec![],
            into_table: None,
            set_operations: vec![],
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// A CTE declared inside a `FROM` subquery must end up on the outer statement, and the
    /// subquery must be left without any CTEs of its own.
    #[test]
    fn test_simple_cte_hoisting() {
        // Body of the CTE: reads `col1` from `table1`.
        let cte_body = SelectStatement {
            columns: vec!["col1".to_string()],
            from_table: Some("table1".to_string()),
            ..blank_select()
        };

        // Subquery that declares the CTE `inner` and selects from it.
        let subquery = SelectStatement {
            ctes: vec![CTE {
                name: "inner".to_string(),
                column_list: None,
                cte_type: CTEType::Standard(cte_body),
            }],
            from_table: Some("inner".to_string()),
            ..blank_select()
        };

        // Outer query whose only source is the subquery above.
        let outer = SelectStatement {
            from_subquery: Some(Box::new(subquery)),
            ..blank_select()
        };

        let hoisted = CTEHoister::hoist_ctes(outer);

        assert_eq!(hoisted.ctes.len(), 1);
        assert_eq!(hoisted.ctes[0].name, "inner");
        assert!(hoisted.from_subquery.as_ref().unwrap().ctes.is_empty());
    }
}