1use std::{
4 collections::{HashMap, HashSet},
5 sync::Arc,
6};
7
8use crate::errors::ExecutorError;
9
/// A materialized common table expression: the schema derived for it paired with
/// its rows. The rows sit behind an `Arc`, so cloning a `CteResult` (e.g. when the
/// map of CTE bindings is cloned) shares the rows instead of copying them.
pub type CteResult = (vibesql_catalog::TableSchema, Arc<Vec<vibesql_storage::Row>>);
18
19pub fn execute_ctes<F>(
23 ctes: &[vibesql_ast::CommonTableExpr],
24 executor: F,
25) -> Result<HashMap<String, CteResult>, ExecutorError>
26where
27 F: Fn(
28 &vibesql_ast::SelectStmt,
29 &HashMap<String, CteResult>,
30 ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
31{
32 execute_ctes_with_memory_check(ctes, executor, |_| Ok(()))
34}
35
36pub(super) fn execute_ctes_with_memory_check<F, M>(
42 ctes: &[vibesql_ast::CommonTableExpr],
43 executor: F,
44 memory_check: M,
45) -> Result<HashMap<String, CteResult>, ExecutorError>
46where
47 F: Fn(
48 &vibesql_ast::SelectStmt,
49 &HashMap<String, CteResult>,
50 ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
51 M: Fn(usize) -> Result<(), ExecutorError>,
52{
53 let mut cte_results = HashMap::new();
54
55 for cte in ctes {
58 let is_recursive = cte.recursive || is_cte_self_referential(cte);
62 let rows = if is_recursive {
63 execute_recursive_cte(cte, &cte_results, &executor, &memory_check)?
65 } else {
66 executor(&cte.query, &cte_results)?
68 };
69
70 let estimated_size = super::helpers::estimate_result_size(&rows);
72 memory_check(estimated_size)?;
73
74 let schema = derive_cte_schema(cte, &rows)?;
76
77 cte_results.insert(cte.name.clone(), (schema, Arc::new(rows)));
79 }
80
81 Ok(cte_results)
82}
83
84pub(super) fn derive_cte_schema(
86 cte: &vibesql_ast::CommonTableExpr,
87 rows: &[vibesql_storage::Row],
88) -> Result<vibesql_catalog::TableSchema, ExecutorError> {
89 if let Some(column_names) = &cte.columns {
91 if let Some(first_row) = rows.first() {
93 if first_row.values.len() != column_names.len() {
94 return Err(ExecutorError::UnsupportedFeature(format!(
95 "CTE column count mismatch: specified {} columns but query returned {}",
96 column_names.len(),
97 first_row.values.len()
98 )));
99 }
100
101 let columns = column_names
102 .iter()
103 .zip(&first_row.values)
104 .map(|(name, value)| {
105 let data_type = infer_type_from_value(value);
106 vibesql_catalog::ColumnSchema::new(name.clone(), data_type, true)
107 })
110 .collect();
111
112 Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
113 } else {
114 let columns = column_names
116 .iter()
117 .map(|name| {
118 vibesql_catalog::ColumnSchema::new(
119 name.clone(),
120 vibesql_types::DataType::Varchar { max_length: Some(255) },
121 true,
122 )
123 })
124 .collect();
125
126 Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
127 }
128 } else {
129 let columns = cte
132 .query
133 .select_list
134 .iter()
135 .enumerate()
136 .map(|(i, item)| {
137 let data_type = if let Some(first_row) = rows.first() {
139 infer_type_from_value(&first_row.values[i])
140 } else {
141 vibesql_types::DataType::Varchar { max_length: Some(255) }
143 };
144
145 let col_name = match item {
147 vibesql_ast::SelectItem::Wildcard { .. }
148 | vibesql_ast::SelectItem::QualifiedWildcard { .. } => format!("col{}", i),
149 vibesql_ast::SelectItem::Expression { expr, alias, .. } => {
150 if let Some(a) = alias {
151 a.clone()
152 } else {
153 match expr {
155 vibesql_ast::Expression::ColumnRef(col_id) => {
156 col_id.column_canonical().to_string()
157 }
158 _ => format!("col{}", i),
159 }
160 }
161 }
162 };
163
164 vibesql_catalog::ColumnSchema::new(col_name, data_type, true) })
166 .collect();
167
168 Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
169 }
170}
171
172fn execute_recursive_cte<F, M>(
192 cte: &vibesql_ast::CommonTableExpr,
193 cte_results: &HashMap<String, CteResult>,
194 executor: &F,
195 memory_check: &M,
196) -> Result<Vec<vibesql_storage::Row>, ExecutorError>
197where
198 F: Fn(
199 &vibesql_ast::SelectStmt,
200 &HashMap<String, CteResult>,
201 ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
202 M: Fn(usize) -> Result<(), ExecutorError>,
203{
204 use crate::limits::MAX_RECURSIVE_CTE_ITERATIONS;
205
206 let set_op = cte.query.set_operation.as_ref().ok_or_else(|| {
208 ExecutorError::UnsupportedFeature(format!(
209 "Recursive CTE '{}' must use UNION ALL",
210 cte.name
211 ))
212 })?;
213
214 if set_op.op != vibesql_ast::SetOperator::Union {
215 return Err(ExecutorError::UnsupportedFeature(format!(
216 "Recursive CTE '{}' must use UNION or UNION ALL (not INTERSECT or EXCEPT)",
217 cte.name
218 )));
219 }
220
221 let base_query = vibesql_ast::SelectStmt {
228 with_clause: cte.query.with_clause.clone(),
229 distinct: cte.query.distinct,
230 select_list: cte.query.select_list.clone(),
231 into_table: cte.query.into_table.clone(),
232 into_variables: cte.query.into_variables.clone(),
233 from: cte.query.from.clone(),
234 where_clause: cte.query.where_clause.clone(),
235 group_by: cte.query.group_by.clone(),
236 having: cte.query.having.clone(),
237 order_by: cte.query.order_by.clone(),
238 limit: cte.query.limit.clone(),
239 offset: cte.query.offset.clone(),
240 set_operation: None, values: cte.query.values.clone(),
242 };
243 let recursive_query = &set_op.right;
244
245 if let (Some(base_count), Some(recursive_count)) =
250 (count_stmt_columns(&base_query), count_stmt_columns(recursive_query))
251 {
252 if base_count != recursive_count {
253 return Err(ExecutorError::UnsupportedFeature(
254 "SELECTs to the left and right of UNION ALL do not have the same number of result columns".to_string()
255 ));
256 }
257 }
258 let mut all_rows = executor(&base_query, cte_results)?;
262 let mut working_table = all_rows.clone();
263
264 let schema = derive_cte_schema(cte, &all_rows)?;
266
267 let mut seen_rows: Option<HashSet<vibesql_storage::RowValues>> = if !set_op.all {
270 let mut seen = HashSet::with_capacity(all_rows.len());
271 for row in &all_rows {
272 seen.insert(row.values.clone());
273 }
274 Some(seen)
275 } else {
276 None
277 };
278
279 let mut depth = 0;
281 while !working_table.is_empty() && depth < MAX_RECURSIVE_CTE_ITERATIONS {
282 depth += 1;
283
284 let mut recursive_cte_results = cte_results.clone();
286 recursive_cte_results
287 .insert(cte.name.clone(), (schema.clone(), Arc::new(working_table.clone())));
288
289 let new_rows = executor(recursive_query, &recursive_cte_results)?;
291
292 if new_rows.is_empty() {
294 break;
295 }
296
297 if depth == 1 && !new_rows.is_empty() && !all_rows.is_empty() {
300 let base_col_count = all_rows[0].values.len();
301 let recursive_col_count = new_rows[0].values.len();
302 if base_col_count != recursive_col_count {
303 return Err(ExecutorError::UnsupportedFeature(
304 "SELECTs to the left and right of UNION ALL do not have the same number of result columns".to_string()
305 ));
306 }
307 }
308
309 let estimated_size = super::helpers::estimate_result_size(&new_rows);
311 memory_check(estimated_size)?;
312
313 let rows_to_add: Vec<vibesql_storage::Row> = if let Some(ref mut seen) = seen_rows {
315 new_rows.into_iter().filter(|row| seen.insert(row.values.clone())).collect()
317 } else {
318 new_rows
320 };
321
322 if rows_to_add.is_empty() {
324 break;
325 }
326
327 all_rows.extend(rows_to_add.clone());
329
330 working_table = rows_to_add;
332 }
333
334 if depth >= MAX_RECURSIVE_CTE_ITERATIONS {
336 return Err(ExecutorError::UnsupportedFeature(format!(
337 "Recursive CTE '{}' exceeded maximum iteration limit of {}",
338 cte.name, MAX_RECURSIVE_CTE_ITERATIONS
339 )));
340 }
341
342 Ok(all_rows)
343}
344
345fn count_explicit_columns(select_list: &[vibesql_ast::SelectItem]) -> Option<usize> {
350 let mut count = 0;
351 for item in select_list {
352 match item {
353 vibesql_ast::SelectItem::Expression { .. } => count += 1,
354 vibesql_ast::SelectItem::Wildcard { .. }
356 | vibesql_ast::SelectItem::QualifiedWildcard { .. } => {
357 return None;
358 }
359 }
360 }
361 Some(count)
362}
363
364fn count_stmt_columns(stmt: &vibesql_ast::SelectStmt) -> Option<usize> {
369 if let Some(values_rows) = &stmt.values {
371 return values_rows.first().map(|row| row.len());
372 }
373
374 count_explicit_columns(&stmt.select_list)
376}
377
378fn is_cte_self_referential(cte: &vibesql_ast::CommonTableExpr) -> bool {
384 let set_op = match &cte.query.set_operation {
386 Some(op) if op.op == vibesql_ast::SetOperator::Union => op,
387 _ => return false,
388 };
389
390 stmt_references_table(&set_op.right, &cte.name)
392}
393
394fn stmt_references_table(stmt: &vibesql_ast::SelectStmt, table_name: &str) -> bool {
396 if let Some(from) = &stmt.from {
398 if from_clause_references_table(from, table_name) {
399 return true;
400 }
401 }
402
403 if let Some(where_clause) = &stmt.where_clause {
405 if expr_references_table(where_clause, table_name) {
406 return true;
407 }
408 }
409
410 for item in &stmt.select_list {
412 if let vibesql_ast::SelectItem::Expression { expr, .. } = item {
413 if expr_references_table(expr, table_name) {
414 return true;
415 }
416 }
417 }
418
419 false
420}
421
422fn from_clause_references_table(from: &vibesql_ast::FromClause, table_name: &str) -> bool {
424 match from {
425 vibesql_ast::FromClause::Table { name, .. } => {
426 name.eq_ignore_ascii_case(table_name)
427 }
428 vibesql_ast::FromClause::Subquery { query, .. } => {
429 stmt_references_table(query, table_name)
430 }
431 vibesql_ast::FromClause::Join { left, right, condition, .. } => {
432 from_clause_references_table(left, table_name)
433 || from_clause_references_table(right, table_name)
434 || condition.as_ref().map_or(false, |c| expr_references_table(c, table_name))
435 }
436 vibesql_ast::FromClause::Values { .. } => false,
437 }
438}
439
440fn expr_references_table(expr: &vibesql_ast::Expression, table_name: &str) -> bool {
442 match expr {
443 vibesql_ast::Expression::ScalarSubquery(subquery) => {
444 stmt_references_table(subquery, table_name)
445 }
446 vibesql_ast::Expression::In { subquery, .. } => {
447 stmt_references_table(subquery, table_name)
448 }
449 vibesql_ast::Expression::Exists { subquery, .. } => {
450 stmt_references_table(subquery, table_name)
451 }
452 vibesql_ast::Expression::BinaryOp { left, right, .. } => {
453 expr_references_table(left, table_name) || expr_references_table(right, table_name)
454 }
455 vibesql_ast::Expression::UnaryOp { expr, .. } => {
456 expr_references_table(expr, table_name)
457 }
458 vibesql_ast::Expression::Function { args, .. } => {
459 args.iter().any(|arg| expr_references_table(arg, table_name))
460 }
461 vibesql_ast::Expression::AggregateFunction { args, filter, .. } => {
462 args.iter().any(|arg| expr_references_table(arg, table_name))
463 || filter.as_ref().map_or(false, |f| expr_references_table(f, table_name))
464 }
465 vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
466 operand.as_ref().map_or(false, |o| expr_references_table(o, table_name))
467 || when_clauses.iter().any(|when| {
468 when.conditions.iter().any(|c| expr_references_table(c, table_name))
469 || expr_references_table(&when.result, table_name)
470 })
471 || else_result.as_ref().map_or(false, |e| expr_references_table(e, table_name))
472 }
473 vibesql_ast::Expression::Between { expr, low, high, .. } => {
474 expr_references_table(expr, table_name)
475 || expr_references_table(low, table_name)
476 || expr_references_table(high, table_name)
477 }
478 vibesql_ast::Expression::InList { expr, values, .. } => {
479 expr_references_table(expr, table_name)
480 || values.iter().any(|e| expr_references_table(e, table_name))
481 }
482 vibesql_ast::Expression::Cast { expr, .. }
483 | vibesql_ast::Expression::Collate { expr, .. } => {
484 expr_references_table(expr, table_name)
485 }
486 vibesql_ast::Expression::Conjunction(exprs)
487 | vibesql_ast::Expression::Disjunction(exprs) => {
488 exprs.iter().any(|e| expr_references_table(e, table_name))
489 }
490 vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
491 expr_references_table(expr, table_name) || stmt_references_table(subquery, table_name)
492 }
493 _ => false,
494 }
495}
496
497pub(super) fn infer_type_from_value(value: &vibesql_types::SqlValue) -> vibesql_types::DataType {
499 match value {
500 vibesql_types::SqlValue::Null => vibesql_types::DataType::Varchar { max_length: Some(255) }, vibesql_types::SqlValue::Integer(_) => vibesql_types::DataType::Integer,
502 vibesql_types::SqlValue::Varchar(_) => {
503 vibesql_types::DataType::Varchar { max_length: Some(255) }
504 }
505 vibesql_types::SqlValue::Character(_) => vibesql_types::DataType::Character { length: 1 },
506 vibesql_types::SqlValue::Boolean(_) => vibesql_types::DataType::Boolean,
507 vibesql_types::SqlValue::Float(_) => vibesql_types::DataType::Float { precision: 53 },
508 vibesql_types::SqlValue::Double(_) => vibesql_types::DataType::DoublePrecision,
509 vibesql_types::SqlValue::Numeric(_) => {
510 vibesql_types::DataType::Numeric { precision: 10, scale: 2 }
511 }
512 vibesql_types::SqlValue::Real(_) => vibesql_types::DataType::Real,
513 vibesql_types::SqlValue::Smallint(_) => vibesql_types::DataType::Smallint,
514 vibesql_types::SqlValue::Bigint(_) => vibesql_types::DataType::Bigint,
515 vibesql_types::SqlValue::Unsigned(_) => vibesql_types::DataType::Unsigned,
516 vibesql_types::SqlValue::Date(_) => vibesql_types::DataType::Date,
517 vibesql_types::SqlValue::Time(_) => vibesql_types::DataType::Time { with_timezone: false },
518 vibesql_types::SqlValue::Timestamp(_) => {
519 vibesql_types::DataType::Timestamp { with_timezone: false }
520 }
521 vibesql_types::SqlValue::Interval(_) => {
522 vibesql_types::DataType::Interval {
524 start_field: vibesql_types::IntervalField::Day,
525 end_field: None,
526 }
527 }
528 vibesql_types::SqlValue::Vector(v) => {
529 vibesql_types::DataType::Vector { dimensions: v.len() as u32 }
530 }
531 vibesql_types::SqlValue::Blob(_) => vibesql_types::DataType::BinaryLargeObject,
532 }
533}