vibesql_executor/select/
cte.rs

1//! Common Table Expression (CTE) handling for SELECT queries
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::errors::ExecutorError;
7
8/// CTE result: (schema, shared rows)
9///
10/// Uses `Arc<Vec<Row>>` to enable O(1) cloning when CTEs are:
11/// - Propagated from outer queries to subqueries
12/// - Referenced multiple times without filtering
13///
14/// This avoids deep-cloning all rows on every CTE reference.
15pub type CteResult = (vibesql_catalog::TableSchema, Arc<Vec<vibesql_storage::Row>>);
16
17/// Execute all CTEs and return their results
18///
19/// CTEs are executed in order, allowing later CTEs to reference earlier ones.
20pub(super) fn execute_ctes<F>(
21    ctes: &[vibesql_ast::CommonTableExpr],
22    executor: F,
23) -> Result<HashMap<String, CteResult>, ExecutorError>
24where
25    F: Fn(
26        &vibesql_ast::SelectStmt,
27        &HashMap<String, CteResult>,
28    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
29{
30    // Use the memory-tracking version with a no-op memory check
31    execute_ctes_with_memory_check(ctes, executor, |_| Ok(()))
32}
33
34/// Execute all CTEs with memory tracking
35///
36/// CTEs are executed in order, allowing later CTEs to reference earlier ones.
37/// After each CTE is materialized, the memory_check callback is called with
38/// the estimated size of the CTE result to enforce memory limits.
39pub(super) fn execute_ctes_with_memory_check<F, M>(
40    ctes: &[vibesql_ast::CommonTableExpr],
41    executor: F,
42    memory_check: M,
43) -> Result<HashMap<String, CteResult>, ExecutorError>
44where
45    F: Fn(
46        &vibesql_ast::SelectStmt,
47        &HashMap<String, CteResult>,
48    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
49    M: Fn(usize) -> Result<(), ExecutorError>,
50{
51    let mut cte_results = HashMap::new();
52
53    // Execute each CTE in order
54    // CTEs can reference previously defined CTEs
55    for cte in ctes {
56        // Execute the CTE query with accumulated CTE results so far
57        // This allows later CTEs to reference earlier ones
58        let rows = executor(&cte.query, &cte_results)?;
59
60        // Track memory for this CTE result before storing
61        let estimated_size = super::helpers::estimate_result_size(&rows);
62        memory_check(estimated_size)?;
63
64        //  Determine the schema for this CTE
65        let schema = derive_cte_schema(cte, &rows)?;
66
67        // Store the CTE result wrapped in Arc for efficient sharing
68        cte_results.insert(cte.name.clone(), (schema, Arc::new(rows)));
69    }
70
71    Ok(cte_results)
72}
73
74/// Derive the schema for a CTE from its query and results
75pub(super) fn derive_cte_schema(
76    cte: &vibesql_ast::CommonTableExpr,
77    rows: &[vibesql_storage::Row],
78) -> Result<vibesql_catalog::TableSchema, ExecutorError> {
79    // If column names are explicitly specified, use those
80    if let Some(column_names) = &cte.columns {
81        // Get data types from first row (if available)
82        if let Some(first_row) = rows.first() {
83            if first_row.values.len() != column_names.len() {
84                return Err(ExecutorError::UnsupportedFeature(format!(
85                    "CTE column count mismatch: specified {} columns but query returned {}",
86                    column_names.len(),
87                    first_row.values.len()
88                )));
89            }
90
91            let columns = column_names
92                .iter()
93                .zip(&first_row.values)
94                .map(|(name, value)| {
95                    let data_type = infer_type_from_value(value);
96                    vibesql_catalog::ColumnSchema::new(name.clone(), data_type, true) // nullable for
97                                                                              // simplicity
98                })
99                .collect();
100
101            Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
102        } else {
103            // Empty result set - create schema with VARCHAR columns
104            let columns = column_names
105                .iter()
106                .map(|name| {
107                    vibesql_catalog::ColumnSchema::new(
108                        name.clone(),
109                        vibesql_types::DataType::Varchar { max_length: Some(255) },
110                        true,
111                    )
112                })
113                .collect();
114
115            Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
116        }
117    } else {
118        // No explicit column names - infer from query SELECT list
119        // Extract column names from SELECT items
120        let columns = cte
121            .query
122            .select_list
123            .iter()
124            .enumerate()
125            .map(|(i, item)| {
126                // Infer data type from first row if available, otherwise use default
127                let data_type = if let Some(first_row) = rows.first() {
128                    infer_type_from_value(&first_row.values[i])
129                } else {
130                    // No rows - use default type (VARCHAR)
131                    vibesql_types::DataType::Varchar { max_length: Some(255) }
132                };
133
134                // Extract column name from SELECT item
135                let col_name = match item {
136                    vibesql_ast::SelectItem::Wildcard { .. }
137                    | vibesql_ast::SelectItem::QualifiedWildcard { .. } => format!("col{}", i),
138                    vibesql_ast::SelectItem::Expression { expr, alias } => {
139                        if let Some(a) = alias {
140                            a.clone()
141                        } else {
142                            // Try to extract name from expression
143                            match expr {
144                                vibesql_ast::Expression::ColumnRef { table: _, column } => {
145                                    column.clone()
146                                }
147                                _ => format!("col{}", i),
148                            }
149                        }
150                    }
151                };
152
153                vibesql_catalog::ColumnSchema::new(col_name, data_type, true) // nullable
154            })
155            .collect();
156
157        Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
158    }
159}
160
161/// Infer data type from a SQL value
162pub(super) fn infer_type_from_value(value: &vibesql_types::SqlValue) -> vibesql_types::DataType {
163    match value {
164        vibesql_types::SqlValue::Null => vibesql_types::DataType::Varchar { max_length: Some(255) }, // default
165        vibesql_types::SqlValue::Integer(_) => vibesql_types::DataType::Integer,
166        vibesql_types::SqlValue::Varchar(_) => vibesql_types::DataType::Varchar { max_length: Some(255) },
167        vibesql_types::SqlValue::Character(_) => vibesql_types::DataType::Character { length: 1 },
168        vibesql_types::SqlValue::Boolean(_) => vibesql_types::DataType::Boolean,
169        vibesql_types::SqlValue::Float(_) => vibesql_types::DataType::Float { precision: 53 },
170        vibesql_types::SqlValue::Double(_) => vibesql_types::DataType::DoublePrecision,
171        vibesql_types::SqlValue::Numeric(_) => vibesql_types::DataType::Numeric { precision: 10, scale: 2 },
172        vibesql_types::SqlValue::Real(_) => vibesql_types::DataType::Real,
173        vibesql_types::SqlValue::Smallint(_) => vibesql_types::DataType::Smallint,
174        vibesql_types::SqlValue::Bigint(_) => vibesql_types::DataType::Bigint,
175        vibesql_types::SqlValue::Unsigned(_) => vibesql_types::DataType::Unsigned,
176        vibesql_types::SqlValue::Date(_) => vibesql_types::DataType::Date,
177        vibesql_types::SqlValue::Time(_) => vibesql_types::DataType::Time { with_timezone: false },
178        vibesql_types::SqlValue::Timestamp(_) => vibesql_types::DataType::Timestamp { with_timezone: false },
179        vibesql_types::SqlValue::Interval(_) => {
180            // For now, return a simple INTERVAL type (can be enhanced to detect field types)
181            vibesql_types::DataType::Interval {
182                start_field: vibesql_types::IntervalField::Day,
183                end_field: None,
184            }
185        }
186        vibesql_types::SqlValue::Vector(v) => {
187            vibesql_types::DataType::Vector { dimensions: v.len() as u32 }
188        }
189    }
190}