sql_cli/query_plan/
cte_hoister.rs

1use crate::sql::parser::ast::{
2    CTEType, Condition, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
3};
4use std::collections::{HashMap, HashSet};
5
6/// CTE Hoister - Analyzes and rewrites nested CTEs
7///
8/// This module implements automatic CTE hoisting to transform nested WITH clauses
9/// into a flat list of CTEs at the query's top level. This enables natural nested
10/// query writing while maintaining compatibility with SQL execution.
11///
12/// Example transformation:
13/// ```sql
14/// -- Input (nested):
15/// SELECT * FROM (
16///   WITH inner_cte AS (SELECT ...)
17///   SELECT * FROM inner_cte
18/// )
19///
20/// -- Output (hoisted):
21/// WITH inner_cte AS (SELECT ...)
22/// SELECT * FROM inner_cte
23/// ```
24pub struct CTEHoister {
25    hoisted_ctes: Vec<CTE>,
26    cte_counter: usize,
27    dependency_graph: HashMap<String, HashSet<String>>,
28}
29
30impl CTEHoister {
31    pub fn new() -> Self {
32        Self {
33            hoisted_ctes: Vec::new(),
34            cte_counter: 0,
35            dependency_graph: HashMap::new(),
36        }
37    }
38
39    /// Hoist all nested CTEs to the top level
40    pub fn hoist_ctes(mut statement: SelectStatement) -> SelectStatement {
41        let mut hoister = CTEHoister::new();
42
43        // First collect any existing top-level CTEs
44        for cte in statement.ctes.drain(..) {
45            hoister.add_cte(cte);
46        }
47
48        // Then recursively hoist from the main statement
49        let rewritten = hoister.hoist_from_statement(statement);
50
51        // Build final statement with all hoisted CTEs
52        SelectStatement {
53            ctes: hoister.get_ordered_ctes(),
54            ..rewritten
55        }
56    }
57
58    /// Recursively hoist CTEs from a SELECT statement
59    fn hoist_from_statement(&mut self, mut statement: SelectStatement) -> SelectStatement {
60        // Hoist from subquery in FROM clause
61        if let Some(subquery) = statement.from_subquery.take() {
62            let rewritten_sub = self.hoist_from_statement(*subquery);
63
64            // If the subquery has CTEs, hoist them
65            for cte in rewritten_sub.ctes.clone() {
66                self.add_cte(cte);
67            }
68
69            // Return the subquery without its CTEs (they're hoisted)
70            statement.from_subquery = Some(Box::new(SelectStatement {
71                ctes: Vec::new(),
72                ..rewritten_sub
73            }));
74        }
75
76        // Hoist from CTEs in this statement
77        let local_ctes = statement.ctes.drain(..).collect::<Vec<_>>();
78        for mut cte in local_ctes {
79            // First hoist from within this CTE's query if it's a standard CTE
80            if let CTEType::Standard(query) = cte.cte_type {
81                let hoisted_query = self.hoist_from_statement(query);
82                cte.cte_type = CTEType::Standard(hoisted_query);
83            }
84            // Then add the CTE itself
85            self.add_cte(cte);
86        }
87
88        // Hoist from expressions in SELECT items
89        statement.select_items = statement
90            .select_items
91            .into_iter()
92            .map(|item| self.hoist_from_select_item(item))
93            .collect();
94
95        // Hoist from WHERE clause subqueries
96        if let Some(where_clause) = &mut statement.where_clause {
97            self.hoist_from_where_clause(where_clause);
98        }
99
100        // Return the statement without CTEs (they're all hoisted)
101        SelectStatement {
102            ctes: Vec::new(),
103            ..statement
104        }
105    }
106
107    /// Hoist CTEs from a SELECT item (for subqueries in expressions)
108    fn hoist_from_select_item(&mut self, item: SelectItem) -> SelectItem {
109        match item {
110            SelectItem::Expression { expr, alias } => SelectItem::Expression {
111                expr: self.hoist_from_expression(expr),
112                alias,
113            },
114            other => other,
115        }
116    }
117
118    /// Hoist CTEs from an expression
119    fn hoist_from_expression(&mut self, expr: SqlExpression) -> SqlExpression {
120        match expr {
121            SqlExpression::ScalarSubquery { query } => {
122                let rewritten = self.hoist_from_statement(*query);
123                SqlExpression::ScalarSubquery {
124                    query: Box::new(rewritten),
125                }
126            }
127            SqlExpression::BinaryOp { left, op, right } => SqlExpression::BinaryOp {
128                left: Box::new(self.hoist_from_expression(*left)),
129                op,
130                right: Box::new(self.hoist_from_expression(*right)),
131            },
132            SqlExpression::FunctionCall {
133                name,
134                args,
135                distinct,
136            } => SqlExpression::FunctionCall {
137                name,
138                args: args
139                    .into_iter()
140                    .map(|arg| self.hoist_from_expression(arg))
141                    .collect(),
142                distinct,
143            },
144            SqlExpression::CaseExpression {
145                when_branches,
146                else_branch,
147            } => SqlExpression::CaseExpression {
148                when_branches: when_branches
149                    .into_iter()
150                    .map(|branch| crate::sql::parser::ast::WhenBranch {
151                        condition: Box::new(self.hoist_from_expression(*branch.condition)),
152                        result: Box::new(self.hoist_from_expression(*branch.result)),
153                    })
154                    .collect(),
155                else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
156            },
157            SqlExpression::InList { expr, values } => SqlExpression::InList {
158                expr: Box::new(self.hoist_from_expression(*expr)),
159                values: values
160                    .into_iter()
161                    .map(|e| self.hoist_from_expression(e))
162                    .collect(),
163            },
164            SqlExpression::NotInList { expr, values } => SqlExpression::NotInList {
165                expr: Box::new(self.hoist_from_expression(*expr)),
166                values: values
167                    .into_iter()
168                    .map(|e| self.hoist_from_expression(e))
169                    .collect(),
170            },
171            SqlExpression::InSubquery { expr, subquery } => {
172                let rewritten = self.hoist_from_statement(*subquery);
173                SqlExpression::InSubquery {
174                    expr: Box::new(self.hoist_from_expression(*expr)),
175                    subquery: Box::new(rewritten),
176                }
177            }
178            SqlExpression::NotInSubquery { expr, subquery } => {
179                let rewritten = self.hoist_from_statement(*subquery);
180                SqlExpression::NotInSubquery {
181                    expr: Box::new(self.hoist_from_expression(*expr)),
182                    subquery: Box::new(rewritten),
183                }
184            }
185            SqlExpression::Between { expr, lower, upper } => SqlExpression::Between {
186                expr: Box::new(self.hoist_from_expression(*expr)),
187                lower: Box::new(self.hoist_from_expression(*lower)),
188                upper: Box::new(self.hoist_from_expression(*upper)),
189            },
190            SqlExpression::Not { expr } => SqlExpression::Not {
191                expr: Box::new(self.hoist_from_expression(*expr)),
192            },
193            // For other expression types that might contain subqueries
194            SqlExpression::SimpleCaseExpression {
195                expr,
196                when_branches,
197                else_branch,
198            } => SqlExpression::SimpleCaseExpression {
199                expr: Box::new(self.hoist_from_expression(*expr)),
200                when_branches: when_branches
201                    .into_iter()
202                    .map(|branch| crate::sql::parser::ast::SimpleWhenBranch {
203                        value: Box::new(self.hoist_from_expression(*branch.value)),
204                        result: Box::new(self.hoist_from_expression(*branch.result)),
205                    })
206                    .collect(),
207                else_branch: else_branch.map(|e| Box::new(self.hoist_from_expression(*e))),
208            },
209            // Terminal expressions don't contain subqueries
210            other => other,
211        }
212    }
213
214    /// Hoist CTEs from WHERE clause
215    fn hoist_from_where_clause(&mut self, where_clause: &mut WhereClause) {
216        for condition in &mut where_clause.conditions {
217            condition.expr = self.hoist_from_expression(condition.expr.clone());
218        }
219    }
220
221    /// Recursively hoist from a condition
222    fn hoist_from_condition(&mut self, condition: &mut Condition) {
223        condition.expr = self.hoist_from_expression(condition.expr.clone());
224    }
225
226    /// Add a CTE to the hoisted collection
227    fn add_cte(&mut self, cte: CTE) {
228        // Track dependencies for proper ordering
229        self.analyze_cte_dependencies(&cte);
230        self.hoisted_ctes.push(cte);
231    }
232
233    /// Analyze CTE dependencies for proper ordering
234    fn analyze_cte_dependencies(&mut self, cte: &CTE) {
235        let mut deps = HashSet::new();
236        if let CTEType::Standard(query) = &cte.cte_type {
237            self.find_cte_references(query, &mut deps);
238        }
239        self.dependency_graph.insert(cte.name.clone(), deps);
240    }
241
242    /// Find all CTE references in a statement
243    fn find_cte_references(&self, statement: &SelectStatement, deps: &mut HashSet<String>) {
244        // Check if FROM references a CTE
245        if let Some(table) = &statement.from_table {
246            // Check if this table name is a CTE
247            for cte in &self.hoisted_ctes {
248                if cte.name == *table {
249                    deps.insert(table.clone());
250                }
251            }
252        }
253
254        // Check subquery references
255        if let Some(subquery) = &statement.from_subquery {
256            self.find_cte_references(subquery, deps);
257        }
258
259        // Check JOIN references
260        for join in &statement.joins {
261            // Check if join table is a CTE
262            if let crate::sql::parser::ast::TableSource::Table(table_name) = &join.table {
263                for cte in &self.hoisted_ctes {
264                    if cte.name == *table_name {
265                        deps.insert(table_name.clone());
266                    }
267                }
268            }
269        }
270
271        // Check expressions for CTE references
272        for item in &statement.select_items {
273            if let SelectItem::Expression { expr, .. } = item {
274                self.find_cte_refs_in_expression(expr, deps);
275            }
276        }
277
278        // Check WHERE clause
279        if let Some(where_clause) = &statement.where_clause {
280            for condition in &where_clause.conditions {
281                self.find_cte_refs_in_expression(&condition.expr, deps);
282            }
283        }
284    }
285
286    /// Find CTE references in an expression
287    fn find_cte_refs_in_expression(&self, expr: &SqlExpression, deps: &mut HashSet<String>) {
288        match expr {
289            SqlExpression::ScalarSubquery { query } => {
290                self.find_cte_references(query, deps);
291            }
292            SqlExpression::InSubquery { subquery, .. } => {
293                self.find_cte_references(subquery, deps);
294            }
295            SqlExpression::NotInSubquery { subquery, .. } => {
296                self.find_cte_references(subquery, deps);
297            }
298            SqlExpression::FunctionCall { args, .. } => {
299                for arg in args {
300                    self.find_cte_refs_in_expression(arg, deps);
301                }
302            }
303            SqlExpression::BinaryOp { left, right, .. } => {
304                self.find_cte_refs_in_expression(left, deps);
305                self.find_cte_refs_in_expression(right, deps);
306            }
307            SqlExpression::CaseExpression {
308                when_branches,
309                else_branch,
310            } => {
311                for branch in when_branches {
312                    self.find_cte_refs_in_expression(&branch.condition, deps);
313                    self.find_cte_refs_in_expression(&branch.result, deps);
314                }
315                if let Some(else_expr) = else_branch {
316                    self.find_cte_refs_in_expression(else_expr, deps);
317                }
318            }
319            SqlExpression::SimpleCaseExpression {
320                expr,
321                when_branches,
322                else_branch,
323            } => {
324                self.find_cte_refs_in_expression(expr, deps);
325                for branch in when_branches {
326                    self.find_cte_refs_in_expression(&branch.value, deps);
327                    self.find_cte_refs_in_expression(&branch.result, deps);
328                }
329                if let Some(else_expr) = else_branch {
330                    self.find_cte_refs_in_expression(else_expr, deps);
331                }
332            }
333            SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
334                self.find_cte_refs_in_expression(expr, deps);
335                for value in values {
336                    self.find_cte_refs_in_expression(value, deps);
337                }
338            }
339            SqlExpression::Between { expr, lower, upper } => {
340                self.find_cte_refs_in_expression(expr, deps);
341                self.find_cte_refs_in_expression(lower, deps);
342                self.find_cte_refs_in_expression(upper, deps);
343            }
344            SqlExpression::Not { expr } => {
345                self.find_cte_refs_in_expression(expr, deps);
346            }
347            _ => {}
348        }
349    }
350
351    /// Get CTEs in dependency order
352    fn get_ordered_ctes(self) -> Vec<CTE> {
353        // Simple topological sort
354        let mut result = Vec::new();
355        let mut visited = HashSet::new();
356        let mut temp_mark = HashSet::new();
357
358        fn visit(
359            name: &str,
360            graph: &HashMap<String, HashSet<String>>,
361            ctes: &[CTE],
362            visited: &mut HashSet<String>,
363            temp_mark: &mut HashSet<String>,
364            result: &mut Vec<CTE>,
365        ) {
366            if visited.contains(name) {
367                return;
368            }
369            if temp_mark.contains(name) {
370                // Circular dependency - for now just continue
371                return;
372            }
373
374            temp_mark.insert(name.to_string());
375
376            if let Some(deps) = graph.get(name) {
377                for dep in deps {
378                    visit(dep, graph, ctes, visited, temp_mark, result);
379                }
380            }
381
382            temp_mark.remove(name);
383            visited.insert(name.to_string());
384
385            // Find and add the CTE
386            if let Some(cte) = ctes.iter().find(|c| c.name == name) {
387                result.push(cte.clone());
388            }
389        }
390
391        // Visit all CTEs
392        for cte in &self.hoisted_ctes {
393            visit(
394                &cte.name,
395                &self.dependency_graph,
396                &self.hoisted_ctes,
397                &mut visited,
398                &mut temp_mark,
399                &mut result,
400            );
401        }
402
403        result
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// A `SelectStatement` with every field empty or unset; tests override
    /// only the fields they care about via struct-update syntax.
    fn blank_select() -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: Vec::new(),
            select_items: Vec::new(),
            from_table: None,
            from_subquery: None,
            from_function: None,
            from_alias: None,
            joins: Vec::new(),
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            limit: None,
            offset: None,
            ctes: Vec::new(),
        }
    }

    #[test]
    fn test_simple_cte_hoisting() {
        // Body of the CTE: SELECT col1 FROM table1
        let cte_body = SelectStatement {
            columns: vec!["col1".to_string()],
            from_table: Some("table1".to_string()),
            ..blank_select()
        };

        // FROM subquery: WITH inner AS (<cte_body>) SELECT ... FROM inner
        let subquery = SelectStatement {
            ctes: vec![CTE {
                name: "inner".to_string(),
                column_list: None,
                cte_type: CTEType::Standard(cte_body),
            }],
            from_table: Some("inner".to_string()),
            ..blank_select()
        };

        // Outer query selecting from that subquery, with no CTEs of its own.
        let nested_query = SelectStatement {
            from_subquery: Some(Box::new(subquery)),
            ..blank_select()
        };

        let result = CTEHoister::hoist_ctes(nested_query);

        // The nested CTE must now live at the top level...
        assert_eq!(result.ctes.len(), 1);
        assert_eq!(result.ctes[0].name, "inner");

        // ...and must be gone from the subquery it came from.
        let hoisted_subquery = result.from_subquery.as_ref().unwrap();
        assert!(hoisted_subquery.ctes.is_empty());
    }
}