sql_cli/query_plan/cte_hoister.rs

1use crate::sql::parser::ast::{
2    CTEType, Condition, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
3};
4use std::collections::{HashMap, HashSet};
5
6/// CTE Hoister - Analyzes and rewrites nested CTEs
7///
8/// This module implements automatic CTE hoisting to transform nested WITH clauses
9/// into a flat list of CTEs at the query's top level. This enables natural nested
10/// query writing while maintaining compatibility with SQL execution.
11///
12/// Example transformation:
13/// ```sql
14/// -- Input (nested):
15/// SELECT * FROM (
16///   WITH inner_cte AS (SELECT ...)
17///   SELECT * FROM inner_cte
18/// )
19///
20/// -- Output (hoisted):
21/// WITH inner_cte AS (SELECT ...)
22/// SELECT * FROM inner_cte
23/// ```
24pub struct CTEHoister {
25    hoisted_ctes: Vec<CTE>,
26    _cte_counter: usize,
27    dependency_graph: HashMap<String, HashSet<String>>,
28}
29
30impl CTEHoister {
31    pub fn new() -> Self {
32        Self {
33            hoisted_ctes: Vec::new(),
34            _cte_counter: 0,
35            dependency_graph: HashMap::new(),
36        }
37    }
38
39    /// Hoist all nested CTEs to the top level
40    pub fn hoist_ctes(mut statement: SelectStatement) -> SelectStatement {
41        let mut hoister = CTEHoister::new();
42
43        // First collect any existing top-level CTEs
44        for cte in statement.ctes.drain(..) {
45            hoister.add_cte(cte);
46        }
47
48        // Then recursively hoist from the main statement
49        let rewritten = hoister.hoist_from_statement(statement);
50
51        // Build final statement with all hoisted CTEs
52        SelectStatement {
53            ctes: hoister.get_ordered_ctes(),
54            ..rewritten
55        }
56    }
57
58    /// Recursively hoist CTEs from a SELECT statement
59    fn hoist_from_statement(&mut self, mut statement: SelectStatement) -> SelectStatement {
60        // Hoist from subquery in FROM clause
61        if let Some(subquery) = statement.from_subquery.take() {
62            let rewritten_sub = self.hoist_from_statement(*subquery);
63
64            // If the subquery has CTEs, hoist them
65            for cte in rewritten_sub.ctes.clone() {
66                self.add_cte(cte);
67            }
68
69            // Return the subquery without its CTEs (they're hoisted)
70            statement.from_subquery = Some(Box::new(SelectStatement {
71                ctes: Vec::new(),
72                ..rewritten_sub
73            }));
74        }
75
76        // Hoist from CTEs in this statement
77        let local_ctes = statement.ctes.drain(..).collect::<Vec<_>>();
78        for mut cte in local_ctes {
79            // First hoist from within this CTE's query if it's a standard CTE
80            if let CTEType::Standard(query) = cte.cte_type {
81                let hoisted_query = self.hoist_from_statement(query);
82                cte.cte_type = CTEType::Standard(hoisted_query);
83            }
84            // Then add the CTE itself
85            self.add_cte(cte);
86        }
87
88        // Hoist from expressions in SELECT items
89        statement.select_items = statement
90            .select_items
91            .into_iter()
92            .map(|item| self.hoist_from_select_item(item))
93            .collect();
94
95        // Hoist from WHERE clause subqueries
96        if let Some(where_clause) = &mut statement.where_clause {
97            self.hoist_from_where_clause(where_clause);
98        }
99
100        // Return the statement without CTEs (they're all hoisted)
101        SelectStatement {
102            ctes: Vec::new(),
103            ..statement
104        }
105    }
106
107    /// Hoist CTEs from a SELECT item (for subqueries in expressions)
108    fn hoist_from_select_item(&mut self, item: SelectItem) -> SelectItem {
109        match item {
110            SelectItem::Expression {
111                expr,
112                alias,
113                leading_comments,
114                trailing_comment,
115            } => SelectItem::Expression {
116                expr: self.hoist_from_expression(expr),
117                alias,
118                leading_comments,
119                trailing_comment,
120            },
121            other => other,
122        }
123    }
124
125    /// Hoist CTEs from an expression
126    fn hoist_from_expression(&mut self, expr: SqlExpression) -> SqlExpression {
127        match expr {
128            SqlExpression::ScalarSubquery { query } => {
129                let rewritten = self.hoist_from_statement(*query);
130                SqlExpression::ScalarSubquery {
131                    query: Box::new(rewritten),
132                }
133            }
134            SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
135                left: Box::new(self.hoist_from_expression(*left)),
136                op,
137                right: Box::new(self.hoist_from_expression(*right)),
138            },
139            SqlExpression::FunctionCall {
140                name,
141                args,
142                distinct,
143            } => SqlExpression::FunctionCall {
144                name,
145                args: args
146                    .into_iter()
147                    .map(|arg| self.hoist_from_expression(arg))
148                    .collect(),
149                distinct,
150            },
151            SqlExpression::CaseExpression {
152                when_branches,
153                else_branch,
154            } => SqlExpression::CaseExpression {
155                when_branches: when_branches
156                    .into_iter()
157                    .map(|branch| crate::sql::parser::ast::WhenBranch {
158                        condition: Box::new(self.hoist_from_expression(*branch.condition)),
159                        result: Box::new(self.hoist_from_expression(*branch.result)),
160                    })
161                    .collect(),
162                else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
163            },
164            SqlExpression::InList { expr, values } => SqlExpression::InList {
165                expr: Box::new(self.hoist_from_expression(*expr)),
166                values: values
167                    .into_iter()
168                    .map(|e| self.hoist_from_expression(e))
169                    .collect(),
170            },
171            SqlExpression::NotInList { expr, values } => SqlExpression::NotInList {
172                expr: Box::new(self.hoist_from_expression(*expr)),
173                values: values
174                    .into_iter()
175                    .map(|e| self.hoist_from_expression(e))
176                    .collect(),
177            },
178            SqlExpression::InSubquery { expr, subquery } => {
179                let rewritten = self.hoist_from_statement(*subquery);
180                SqlExpression::InSubquery {
181                    expr: Box::new(self.hoist_from_expression(*expr)),
182                    subquery: Box::new(rewritten),
183                }
184            }
185            SqlExpression::NotInSubquery { expr, subquery } => {
186                let rewritten = self.hoist_from_statement(*subquery);
187                SqlExpression::NotInSubquery {
188                    expr: Box::new(self.hoist_from_expression(*expr)),
189                    subquery: Box::new(rewritten),
190                }
191            }
192            SqlExpression::Between { expr, lower, upper } => SqlExpression::Between {
193                expr: Box::new(self.hoist_from_expression(*expr)),
194                lower: Box::new(self.hoist_from_expression(*lower)),
195                upper: Box::new(self.hoist_from_expression(*upper)),
196            },
197            SqlExpression::Not { expr } => SqlExpression::Not {
198                expr: Box::new(self.hoist_from_expression(*expr)),
199            },
200            // For other expression types that might contain subqueries
201            SqlExpression::SimpleCaseExpression {
202                expr,
203                when_branches,
204                else_branch,
205            } => SqlExpression::SimpleCaseExpression {
206                expr: Box::new(self.hoist_from_expression(*expr)),
207                when_branches: when_branches
208                    .into_iter()
209                    .map(|branch| crate::sql::parser::ast::SimpleWhenBranch {
210                        value: Box::new(self.hoist_from_expression(*branch.value)),
211                        result: Box::new(self.hoist_from_expression(*branch.result)),
212                    })
213                    .collect(),
214                else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
215            },
216            // Terminal expressions don't contain subqueries
217            other => other,
218        }
219    }
220
221    /// Hoist CTEs from WHERE clause
222    fn hoist_from_where_clause(&mut self, where_clause: &mut WhereClause) {
223        for condition in &mut where_clause.conditions {
224            condition.expr = self.hoist_from_expression(condition.expr.clone());
225        }
226    }
227
228    /// Recursively hoist from a condition
229    fn hoist_from_condition(&mut self, condition: &mut Condition) {
230        condition.expr = self.hoist_from_expression(condition.expr.clone());
231    }
232
233    /// Add a CTE to the hoisted collection
234    fn add_cte(&mut self, cte: CTE) {
235        // Track dependencies for proper ordering
236        self.analyze_cte_dependencies(&cte);
237        self.hoisted_ctes.push(cte);
238    }
239
240    /// Analyze CTE dependencies for proper ordering
241    fn analyze_cte_dependencies(&mut self, cte: &CTE) {
242        let mut deps = HashSet::new();
243        if let CTEType::Standard(query) = &cte.cte_type {
244            self.find_cte_references(query, &mut deps);
245        }
246        self.dependency_graph.insert(cte.name.clone(), deps);
247    }
248
249    /// Find all CTE references in a statement
250    fn find_cte_references(&self, statement: &SelectStatement, deps: &mut HashSet<String>) {
251        // Check if FROM references a CTE
252        if let Some(table) = &statement.from_table {
253            // Check if this table name is a CTE
254            for cte in &self.hoisted_ctes {
255                if cte.name == *table {
256                    deps.insert(table.clone());
257                }
258            }
259        }
260
261        // Check subquery references
262        if let Some(subquery) = &statement.from_subquery {
263            self.find_cte_references(subquery, deps);
264        }
265
266        // Check JOIN references
267        for join in &statement.joins {
268            // Check if join table is a CTE
269            if let crate::sql::parser::ast::TableSource::Table(table_name) = &join.table {
270                for cte in &self.hoisted_ctes {
271                    if cte.name == *table_name {
272                        deps.insert(table_name.clone());
273                    }
274                }
275            }
276        }
277
278        // Check expressions for CTE references
279        for item in &statement.select_items {
280            if let SelectItem::Expression { expr, .. } = item {
281                self.find_cte_refs_in_expression(expr, deps);
282            }
283        }
284
285        // Check WHERE clause
286        if let Some(where_clause) = &statement.where_clause {
287            for condition in &where_clause.conditions {
288                self.find_cte_refs_in_expression(&condition.expr, deps);
289            }
290        }
291    }
292
293    /// Find CTE references in an expression
294    fn find_cte_refs_in_expression(&self, expr: &SqlExpression, deps: &mut HashSet<String>) {
295        match expr {
296            SqlExpression::ScalarSubquery { query } => {
297                self.find_cte_references(query, deps);
298            }
299            SqlExpression::InSubquery { subquery, .. } => {
300                self.find_cte_references(subquery, deps);
301            }
302            SqlExpression::NotInSubquery { subquery, .. } => {
303                self.find_cte_references(subquery, deps);
304            }
305            SqlExpression::FunctionCall { args, .. } => {
306                for arg in args {
307                    self.find_cte_refs_in_expression(arg, deps);
308                }
309            }
310            SqlExpression::BinaryOp { left, right, .. } => {
311                self.find_cte_refs_in_expression(left, deps);
312                self.find_cte_refs_in_expression(right, deps);
313            }
314            SqlExpression::CaseExpression {
315                when_branches,
316                else_branch,
317            } => {
318                for branch in when_branches {
319                    self.find_cte_refs_in_expression(&branch.condition, deps);
320                    self.find_cte_refs_in_expression(&branch.result, deps);
321                }
322                if let Some(else_expr) = else_branch {
323                    self.find_cte_refs_in_expression(else_expr, deps);
324                }
325            }
326            SqlExpression::SimpleCaseExpression {
327                expr,
328                when_branches,
329                else_branch,
330            } => {
331                self.find_cte_refs_in_expression(expr, deps);
332                for branch in when_branches {
333                    self.find_cte_refs_in_expression(&branch.value, deps);
334                    self.find_cte_refs_in_expression(&branch.result, deps);
335                }
336                if let Some(else_expr) = else_branch {
337                    self.find_cte_refs_in_expression(else_expr, deps);
338                }
339            }
340            SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
341                self.find_cte_refs_in_expression(expr, deps);
342                for value in values {
343                    self.find_cte_refs_in_expression(value, deps);
344                }
345            }
346            SqlExpression::Between { expr, lower, upper } => {
347                self.find_cte_refs_in_expression(expr, deps);
348                self.find_cte_refs_in_expression(lower, deps);
349                self.find_cte_refs_in_expression(upper, deps);
350            }
351            SqlExpression::Not { expr } => {
352                self.find_cte_refs_in_expression(expr, deps);
353            }
354            _ => {}
355        }
356    }
357
358    /// Get CTEs in dependency order
359    fn get_ordered_ctes(self) -> Vec<CTE> {
360        // Simple topological sort
361        let mut result = Vec::new();
362        let mut visited = HashSet::new();
363        let mut temp_mark = HashSet::new();
364
365        fn visit(
366            name: &str,
367            graph: &HashMap<String, HashSet<String>>,
368            ctes: &[CTE],
369            visited: &mut HashSet<String>,
370            temp_mark: &mut HashSet<String>,
371            result: &mut Vec<CTE>,
372        ) {
373            if visited.contains(name) {
374                return;
375            }
376            if temp_mark.contains(name) {
377                // Circular dependency - for now just continue
378                return;
379            }
380
381            temp_mark.insert(name.to_string());
382
383            if let Some(deps) = graph.get(name) {
384                for dep in deps {
385                    visit(dep, graph, ctes, visited, temp_mark, result);
386                }
387            }
388
389            temp_mark.remove(name);
390            visited.insert(name.to_string());
391
392            // Find and add the CTE
393            if let Some(cte) = ctes.iter().find(|c| c.name == name) {
394                result.push(cte.clone());
395            }
396        }
397
398        // Visit all CTEs
399        for cte in &self.hoisted_ctes {
400            visit(
401                &cte.name,
402                &self.dependency_graph,
403                &self.hoisted_ctes,
404                &mut visited,
405                &mut temp_mark,
406                &mut result,
407            );
408        }
409
410        result
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// A SELECT with every field empty. Tests override only the fields they need.
    fn blank_select() -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: Vec::new(),
            select_items: Vec::new(),
            from_table: None,
            from_subquery: None,
            from_function: None,
            from_alias: None,
            joins: Vec::new(),
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            limit: None,
            offset: None,
            ctes: Vec::new(),
            into_table: None,
            set_operations: Vec::new(),
            leading_comments: Vec::new(),
            trailing_comment: None,
        }
    }

    #[test]
    fn test_simple_cte_hoisting() {
        // Body of the nested CTE: reads `col1` from `table1`.
        let cte_body = SelectStatement {
            columns: vec!["col1".to_string()],
            from_table: Some("table1".to_string()),
            ..blank_select()
        };

        // A FROM subquery that declares `inner` in its own WITH and reads from it.
        let subquery_with_cte = SelectStatement {
            ctes: vec![CTE {
                name: "inner".to_string(),
                column_list: None,
                cte_type: CTEType::Standard(cte_body),
            }],
            from_table: Some("inner".to_string()),
            ..blank_select()
        };

        let outer = SelectStatement {
            from_subquery: Some(Box::new(subquery_with_cte)),
            ..blank_select()
        };

        let hoisted = CTEHoister::hoist_ctes(outer);

        // The CTE moves to the top level, and the subquery keeps no WITH of its own.
        assert_eq!(hoisted.ctes.len(), 1);
        assert_eq!(hoisted.ctes[0].name, "inner");
        let remaining_sub = hoisted
            .from_subquery
            .as_ref()
            .expect("FROM subquery should be preserved");
        assert!(remaining_sub.ctes.is_empty());
    }
}