vibesql_executor/select/
cte.rs

1//! Common Table Expression (CTE) handling for SELECT queries
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::Arc,
6};
7
8use crate::errors::ExecutorError;
9
10/// CTE result: (schema, shared rows)
11///
12/// Uses `Arc<Vec<Row>>` to enable O(1) cloning when CTEs are:
13/// - Propagated from outer queries to subqueries
14/// - Referenced multiple times without filtering
15///
16/// This avoids deep-cloning all rows on every CTE reference.
17pub type CteResult = (vibesql_catalog::TableSchema, Arc<Vec<vibesql_storage::Row>>);
18
19/// Execute all CTEs and return their results
20///
21/// CTEs are executed in order, allowing later CTEs to reference earlier ones.
22pub fn execute_ctes<F>(
23    ctes: &[vibesql_ast::CommonTableExpr],
24    executor: F,
25) -> Result<HashMap<String, CteResult>, ExecutorError>
26where
27    F: Fn(
28        &vibesql_ast::SelectStmt,
29        &HashMap<String, CteResult>,
30    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
31{
32    // Use the memory-tracking version with a no-op memory check
33    execute_ctes_with_memory_check(ctes, executor, |_| Ok(()))
34}
35
36/// Execute all CTEs with memory tracking
37///
38/// CTEs are executed in order, allowing later CTEs to reference earlier ones.
39/// After each CTE is materialized, the memory_check callback is called with
40/// the estimated size of the CTE result to enforce memory limits.
41pub(super) fn execute_ctes_with_memory_check<F, M>(
42    ctes: &[vibesql_ast::CommonTableExpr],
43    executor: F,
44    memory_check: M,
45) -> Result<HashMap<String, CteResult>, ExecutorError>
46where
47    F: Fn(
48        &vibesql_ast::SelectStmt,
49        &HashMap<String, CteResult>,
50    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
51    M: Fn(usize) -> Result<(), ExecutorError>,
52{
53    let mut cte_results = HashMap::new();
54
55    // Execute each CTE in order
56    // CTEs can reference previously defined CTEs
57    for cte in ctes {
58        // Check if this is a recursive CTE
59        // SQLite compatibility: auto-detect recursive CTEs even without RECURSIVE keyword
60        // A CTE is recursive if it references itself in a UNION/UNION ALL set operation
61        let is_recursive = cte.recursive || is_cte_self_referential(cte);
62        let rows = if is_recursive {
63            // Recursive CTE: execute base term, then iteratively execute recursive term
64            execute_recursive_cte(cte, &cte_results, &executor, &memory_check)?
65        } else {
66            // Non-recursive CTE: execute query directly
67            executor(&cte.query, &cte_results)?
68        };
69
70        // Track memory for this CTE result before storing
71        let estimated_size = super::helpers::estimate_result_size(&rows);
72        memory_check(estimated_size)?;
73
74        //  Determine the schema for this CTE
75        let schema = derive_cte_schema(cte, &rows)?;
76
77        // Store the CTE result wrapped in Arc for efficient sharing
78        cte_results.insert(cte.name.clone(), (schema, Arc::new(rows)));
79    }
80
81    Ok(cte_results)
82}
83
84/// Derive the schema for a CTE from its query and results
85pub(super) fn derive_cte_schema(
86    cte: &vibesql_ast::CommonTableExpr,
87    rows: &[vibesql_storage::Row],
88) -> Result<vibesql_catalog::TableSchema, ExecutorError> {
89    // If column names are explicitly specified, use those
90    if let Some(column_names) = &cte.columns {
91        // Get data types from first row (if available)
92        if let Some(first_row) = rows.first() {
93            if first_row.values.len() != column_names.len() {
94                return Err(ExecutorError::UnsupportedFeature(format!(
95                    "CTE column count mismatch: specified {} columns but query returned {}",
96                    column_names.len(),
97                    first_row.values.len()
98                )));
99            }
100
101            let columns = column_names
102                .iter()
103                .zip(&first_row.values)
104                .map(|(name, value)| {
105                    let data_type = infer_type_from_value(value);
106                    vibesql_catalog::ColumnSchema::new(name.clone(), data_type, true)
107                    // nullable for
108                    // simplicity
109                })
110                .collect();
111
112            Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
113        } else {
114            // Empty result set - create schema with VARCHAR columns
115            let columns = column_names
116                .iter()
117                .map(|name| {
118                    vibesql_catalog::ColumnSchema::new(
119                        name.clone(),
120                        vibesql_types::DataType::Varchar { max_length: Some(255) },
121                        true,
122                    )
123                })
124                .collect();
125
126            Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
127        }
128    } else {
129        // No explicit column names - infer from query SELECT list
130        // Extract column names from SELECT items
131        let columns = cte
132            .query
133            .select_list
134            .iter()
135            .enumerate()
136            .map(|(i, item)| {
137                // Infer data type from first row if available, otherwise use default
138                let data_type = if let Some(first_row) = rows.first() {
139                    infer_type_from_value(&first_row.values[i])
140                } else {
141                    // No rows - use default type (VARCHAR)
142                    vibesql_types::DataType::Varchar { max_length: Some(255) }
143                };
144
145                // Extract column name from SELECT item
146                let col_name = match item {
147                    vibesql_ast::SelectItem::Wildcard { .. }
148                    | vibesql_ast::SelectItem::QualifiedWildcard { .. } => format!("col{}", i),
149                    vibesql_ast::SelectItem::Expression { expr, alias, .. } => {
150                        if let Some(a) = alias {
151                            a.clone()
152                        } else {
153                            // Try to extract name from expression
154                            match expr {
155                                vibesql_ast::Expression::ColumnRef(col_id) => {
156                                    col_id.column_canonical().to_string()
157                                }
158                                _ => format!("col{}", i),
159                            }
160                        }
161                    }
162                };
163
164                vibesql_catalog::ColumnSchema::new(col_name, data_type, true) // nullable
165            })
166            .collect();
167
168        Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
169    }
170}
171
172/// Execute a recursive CTE using iterative evaluation
173///
174/// Recursive CTEs in SQL:1999/SQLite are defined with UNION or UNION ALL:
175/// ```sql
176/// WITH RECURSIVE cte AS (
177///   base_query          -- Executed once to get initial rows
178///   UNION [ALL]
179///   recursive_query     -- References 'cte', executed iteratively
180/// )
181/// ```
182///
183/// Algorithm:
184/// 1. Split query into base and recursive terms (before/after UNION [ALL])
185/// 2. Execute base term to get initial working table
186/// 3. Repeat until no new rows or max depth reached:
187///    - Make working table available as CTE
188///    - Execute recursive term
189///    - Add new rows to result (with deduplication for UNION)
190///    - Update working table to new rows
191fn execute_recursive_cte<F, M>(
192    cte: &vibesql_ast::CommonTableExpr,
193    cte_results: &HashMap<String, CteResult>,
194    executor: &F,
195    memory_check: &M,
196) -> Result<Vec<vibesql_storage::Row>, ExecutorError>
197where
198    F: Fn(
199        &vibesql_ast::SelectStmt,
200        &HashMap<String, CteResult>,
201    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
202    M: Fn(usize) -> Result<(), ExecutorError>,
203{
204    use crate::limits::MAX_RECURSIVE_CTE_ITERATIONS;
205
206    // Validate that recursive CTE uses UNION ALL
207    let set_op = cte.query.set_operation.as_ref().ok_or_else(|| {
208        ExecutorError::UnsupportedFeature(format!(
209            "Recursive CTE '{}' must use UNION ALL",
210            cte.name
211        ))
212    })?;
213
214    if set_op.op != vibesql_ast::SetOperator::Union {
215        return Err(ExecutorError::UnsupportedFeature(format!(
216            "Recursive CTE '{}' must use UNION or UNION ALL (not INTERSECT or EXCEPT)",
217            cte.name
218        )));
219    }
220
221    // Extract base and recursive terms
222    // Base term: the main SELECT (before UNION [ALL])
223    // Recursive term: the right side of UNION [ALL]
224
225    // Create base-only query without the UNION ALL set operation
226    // This prevents the base term from trying to reference the CTE before it exists
227    let base_query = vibesql_ast::SelectStmt {
228        with_clause: cte.query.with_clause.clone(),
229        distinct: cte.query.distinct,
230        select_list: cte.query.select_list.clone(),
231        into_table: cte.query.into_table.clone(),
232        into_variables: cte.query.into_variables.clone(),
233        from: cte.query.from.clone(),
234        where_clause: cte.query.where_clause.clone(),
235        group_by: cte.query.group_by.clone(),
236        having: cte.query.having.clone(),
237        order_by: cte.query.order_by.clone(),
238        limit: cte.query.limit.clone(),
239        offset: cte.query.offset.clone(),
240        set_operation: None, // Remove UNION ALL for base term execution
241        values: cte.query.values.clone(),
242    };
243    let recursive_query = &set_op.right;
244
245    // Try static validation first (works for explicit column lists and VALUES)
246    // This provides better SQLite compatibility by catching errors at prepare time
247    // rather than waiting until runtime
248    // Note: For VALUES statements, column count comes from the VALUES rows, not select_list
249    if let (Some(base_count), Some(recursive_count)) =
250        (count_stmt_columns(&base_query), count_stmt_columns(recursive_query))
251    {
252        if base_count != recursive_count {
253            return Err(ExecutorError::UnsupportedFeature(
254                "SELECTs to the left and right of UNION ALL do not have the same number of result columns".to_string()
255            ));
256        }
257    }
258    // Fall back to runtime validation for wildcards (existing code below at line 279-289)
259
260    // Step 1: Execute base term to get initial rows
261    let mut all_rows = executor(&base_query, cte_results)?;
262    let mut working_table = all_rows.clone();
263
264    // Derive schema from base term result
265    let schema = derive_cte_schema(cte, &all_rows)?;
266
267    // Track seen rows for UNION (deduplication)
268    // For UNION ALL, we skip tracking to preserve all rows
269    let mut seen_rows: Option<HashSet<vibesql_storage::RowValues>> = if !set_op.all {
270        let mut seen = HashSet::with_capacity(all_rows.len());
271        for row in &all_rows {
272            seen.insert(row.values.clone());
273        }
274        Some(seen)
275    } else {
276        None
277    };
278
279    // Step 2: Iterative evaluation
280    let mut depth = 0;
281    while !working_table.is_empty() && depth < MAX_RECURSIVE_CTE_ITERATIONS {
282        depth += 1;
283
284        // Make working table available as this CTE for recursive reference
285        let mut recursive_cte_results = cte_results.clone();
286        recursive_cte_results
287            .insert(cte.name.clone(), (schema.clone(), Arc::new(working_table.clone())));
288
289        // Execute recursive term with working table as CTE
290        let new_rows = executor(recursive_query, &recursive_cte_results)?;
291
292        // If no new rows, we're done
293        if new_rows.is_empty() {
294            break;
295        }
296
297        // Validate that recursive term returns same number of columns as base term
298        // This check is done on first iteration to catch schema mismatches early
299        if depth == 1 && !new_rows.is_empty() && !all_rows.is_empty() {
300            let base_col_count = all_rows[0].values.len();
301            let recursive_col_count = new_rows[0].values.len();
302            if base_col_count != recursive_col_count {
303                return Err(ExecutorError::UnsupportedFeature(
304                    "SELECTs to the left and right of UNION ALL do not have the same number of result columns".to_string()
305                ));
306            }
307        }
308
309        // Check memory before adding new rows
310        let estimated_size = super::helpers::estimate_result_size(&new_rows);
311        memory_check(estimated_size)?;
312
313        // Filter out duplicates for UNION (keep all for UNION ALL)
314        let rows_to_add: Vec<vibesql_storage::Row> = if let Some(ref mut seen) = seen_rows {
315            // UNION: only add rows we haven't seen before
316            new_rows.into_iter().filter(|row| seen.insert(row.values.clone())).collect()
317        } else {
318            // UNION ALL: keep all rows
319            new_rows
320        };
321
322        // If no new unique rows (for UNION), we're done
323        if rows_to_add.is_empty() {
324            break;
325        }
326
327        // Add new rows to result
328        all_rows.extend(rows_to_add.clone());
329
330        // Update working table to be the new rows for next iteration
331        working_table = rows_to_add;
332    }
333
334    // Check if we hit max recursion depth
335    if depth >= MAX_RECURSIVE_CTE_ITERATIONS {
336        return Err(ExecutorError::UnsupportedFeature(format!(
337            "Recursive CTE '{}' exceeded maximum iteration limit of {}",
338            cte.name, MAX_RECURSIVE_CTE_ITERATIONS
339        )));
340    }
341
342    Ok(all_rows)
343}
344
345/// Count columns if select list has only explicit expressions (no wildcards)
346///
347/// Returns Some(count) if all select items are explicit expressions.
348/// Returns None if any wildcards are present (requires schema info to count).
349fn count_explicit_columns(select_list: &[vibesql_ast::SelectItem]) -> Option<usize> {
350    let mut count = 0;
351    for item in select_list {
352        match item {
353            vibesql_ast::SelectItem::Expression { .. } => count += 1,
354            // Can't count wildcards statically - need schema info
355            vibesql_ast::SelectItem::Wildcard { .. }
356            | vibesql_ast::SelectItem::QualifiedWildcard { .. } => {
357                return None;
358            }
359        }
360    }
361    Some(count)
362}
363
364/// Count columns in a SELECT statement, considering both select_list and VALUES.
365/// For VALUES statements, the column count comes from the first row of values.
366/// For SELECT statements, the column count comes from the select_list.
367/// Returns None if any wildcards are present (requires schema info to count).
368fn count_stmt_columns(stmt: &vibesql_ast::SelectStmt) -> Option<usize> {
369    // If this is a VALUES statement, count columns from the first VALUES row
370    if let Some(values_rows) = &stmt.values {
371        return values_rows.first().map(|row| row.len());
372    }
373
374    // Otherwise, count columns from the select_list
375    count_explicit_columns(&stmt.select_list)
376}
377
378/// Check if a CTE is self-referential (references itself in UNION/UNION ALL)
379///
380/// SQLite allows recursive CTEs without the RECURSIVE keyword if the CTE
381/// references itself in a set operation. This function detects such cases
382/// by checking if the right side of a UNION/UNION ALL references the CTE name.
383fn is_cte_self_referential(cte: &vibesql_ast::CommonTableExpr) -> bool {
384    // Check if the CTE has a UNION/UNION ALL set operation
385    let set_op = match &cte.query.set_operation {
386        Some(op) if op.op == vibesql_ast::SetOperator::Union => op,
387        _ => return false,
388    };
389
390    // Check if the recursive term references this CTE
391    stmt_references_table(&set_op.right, &cte.name)
392}
393
394/// Check if a SELECT statement references a table name
395fn stmt_references_table(stmt: &vibesql_ast::SelectStmt, table_name: &str) -> bool {
396    // Check FROM clause
397    if let Some(from) = &stmt.from {
398        if from_clause_references_table(from, table_name) {
399            return true;
400        }
401    }
402
403    // Check WHERE clause for subqueries
404    if let Some(where_clause) = &stmt.where_clause {
405        if expr_references_table(where_clause, table_name) {
406            return true;
407        }
408    }
409
410    // Check SELECT list for subqueries
411    for item in &stmt.select_list {
412        if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
413            if expr_references_table(expr, table_name) {
414                return true;
415            }
416        }
417    }
418
419    false
420}
421
422/// Check if a FROM clause references a table name
423fn from_clause_references_table(from: &vibesql_ast::FromClause, table_name: &str) -> bool {
424    match from {
425        vibesql_ast::FromClause::Table { name, .. } => {
426            name.eq_ignore_ascii_case(table_name)
427        }
428        vibesql_ast::FromClause::Subquery { query, .. } => {
429            stmt_references_table(query, table_name)
430        }
431        vibesql_ast::FromClause::Join { left, right, condition, .. } => {
432            from_clause_references_table(left, table_name)
433                || from_clause_references_table(right, table_name)
434                || condition.as_ref().map_or(false, |c| expr_references_table(c, table_name))
435        }
436        vibesql_ast::FromClause::Values { .. } => false,
437    }
438}
439
440/// Check if an expression references a table name (in subqueries)
441fn expr_references_table(expr: &vibesql_ast::Expression, table_name: &str) -> bool {
442    match expr {
443        vibesql_ast::Expression::ScalarSubquery(subquery) => {
444            stmt_references_table(subquery, table_name)
445        }
446        vibesql_ast::Expression::In { subquery, .. } => {
447            stmt_references_table(subquery, table_name)
448        }
449        vibesql_ast::Expression::Exists { subquery, .. } => {
450            stmt_references_table(subquery, table_name)
451        }
452        vibesql_ast::Expression::BinaryOp { left, right, .. } => {
453            expr_references_table(left, table_name) || expr_references_table(right, table_name)
454        }
455        vibesql_ast::Expression::UnaryOp { expr, .. } => {
456            expr_references_table(expr, table_name)
457        }
458        vibesql_ast::Expression::Function { args, .. } => {
459            args.iter().any(|arg| expr_references_table(arg, table_name))
460        }
461        vibesql_ast::Expression::AggregateFunction { args, filter, .. } => {
462            args.iter().any(|arg| expr_references_table(arg, table_name))
463                || filter.as_ref().map_or(false, |f| expr_references_table(f, table_name))
464        }
465        vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
466            operand.as_ref().map_or(false, |o| expr_references_table(o, table_name))
467                || when_clauses.iter().any(|when| {
468                    when.conditions.iter().any(|c| expr_references_table(c, table_name))
469                        || expr_references_table(&when.result, table_name)
470                })
471                || else_result.as_ref().map_or(false, |e| expr_references_table(e, table_name))
472        }
473        vibesql_ast::Expression::Between { expr, low, high, .. } => {
474            expr_references_table(expr, table_name)
475                || expr_references_table(low, table_name)
476                || expr_references_table(high, table_name)
477        }
478        vibesql_ast::Expression::InList { expr, values, .. } => {
479            expr_references_table(expr, table_name)
480                || values.iter().any(|e| expr_references_table(e, table_name))
481        }
482        vibesql_ast::Expression::Cast { expr, .. }
483        | vibesql_ast::Expression::Collate { expr, .. } => {
484            expr_references_table(expr, table_name)
485        }
486        vibesql_ast::Expression::Conjunction(exprs)
487        | vibesql_ast::Expression::Disjunction(exprs) => {
488            exprs.iter().any(|e| expr_references_table(e, table_name))
489        }
490        vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
491            expr_references_table(expr, table_name) || stmt_references_table(subquery, table_name)
492        }
493        _ => false,
494    }
495}
496
497/// Infer data type from a SQL value
498pub(super) fn infer_type_from_value(value: &vibesql_types::SqlValue) -> vibesql_types::DataType {
499    match value {
500        vibesql_types::SqlValue::Null => vibesql_types::DataType::Varchar { max_length: Some(255) }, /* default */
501        vibesql_types::SqlValue::Integer(_) => vibesql_types::DataType::Integer,
502        vibesql_types::SqlValue::Varchar(_) => {
503            vibesql_types::DataType::Varchar { max_length: Some(255) }
504        }
505        vibesql_types::SqlValue::Character(_) => vibesql_types::DataType::Character { length: 1 },
506        vibesql_types::SqlValue::Boolean(_) => vibesql_types::DataType::Boolean,
507        vibesql_types::SqlValue::Float(_) => vibesql_types::DataType::Float { precision: 53 },
508        vibesql_types::SqlValue::Double(_) => vibesql_types::DataType::DoublePrecision,
509        vibesql_types::SqlValue::Numeric(_) => {
510            vibesql_types::DataType::Numeric { precision: 10, scale: 2 }
511        }
512        vibesql_types::SqlValue::Real(_) => vibesql_types::DataType::Real,
513        vibesql_types::SqlValue::Smallint(_) => vibesql_types::DataType::Smallint,
514        vibesql_types::SqlValue::Bigint(_) => vibesql_types::DataType::Bigint,
515        vibesql_types::SqlValue::Unsigned(_) => vibesql_types::DataType::Unsigned,
516        vibesql_types::SqlValue::Date(_) => vibesql_types::DataType::Date,
517        vibesql_types::SqlValue::Time(_) => vibesql_types::DataType::Time { with_timezone: false },
518        vibesql_types::SqlValue::Timestamp(_) => {
519            vibesql_types::DataType::Timestamp { with_timezone: false }
520        }
521        vibesql_types::SqlValue::Interval(_) => {
522            // For now, return a simple INTERVAL type (can be enhanced to detect field types)
523            vibesql_types::DataType::Interval {
524                start_field: vibesql_types::IntervalField::Day,
525                end_field: None,
526            }
527        }
528        vibesql_types::SqlValue::Vector(v) => {
529            vibesql_types::DataType::Vector { dimensions: v.len() as u32 }
530        }
531        vibesql_types::SqlValue::Blob(_) => vibesql_types::DataType::BinaryLargeObject,
532    }
533}