vibesql_executor/select/
cte.rs

//! Common Table Expression (CTE) handling for SELECT queries
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::errors::ExecutorError;
7
8/// CTE result: (schema, shared rows)
9///
10/// Uses `Arc<Vec<Row>>` to enable O(1) cloning when CTEs are:
11/// - Propagated from outer queries to subqueries
12/// - Referenced multiple times without filtering
13///
14/// This avoids deep-cloning all rows on every CTE reference.
15pub type CteResult = (vibesql_catalog::TableSchema, Arc<Vec<vibesql_storage::Row>>);
16
17/// Execute all CTEs and return their results
18///
19/// CTEs are executed in order, allowing later CTEs to reference earlier ones.
20pub(super) fn execute_ctes<F>(
21    ctes: &[vibesql_ast::CommonTableExpr],
22    executor: F,
23) -> Result<HashMap<String, CteResult>, ExecutorError>
24where
25    F: Fn(
26        &vibesql_ast::SelectStmt,
27        &HashMap<String, CteResult>,
28    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
29{
30    // Use the memory-tracking version with a no-op memory check
31    execute_ctes_with_memory_check(ctes, executor, |_| Ok(()))
32}
33
34/// Execute all CTEs with memory tracking
35///
36/// CTEs are executed in order, allowing later CTEs to reference earlier ones.
37/// After each CTE is materialized, the memory_check callback is called with
38/// the estimated size of the CTE result to enforce memory limits.
39pub(super) fn execute_ctes_with_memory_check<F, M>(
40    ctes: &[vibesql_ast::CommonTableExpr],
41    executor: F,
42    memory_check: M,
43) -> Result<HashMap<String, CteResult>, ExecutorError>
44where
45    F: Fn(
46        &vibesql_ast::SelectStmt,
47        &HashMap<String, CteResult>,
48    ) -> Result<Vec<vibesql_storage::Row>, ExecutorError>,
49    M: Fn(usize) -> Result<(), ExecutorError>,
50{
51    let mut cte_results = HashMap::new();
52
53    // Execute each CTE in order
54    // CTEs can reference previously defined CTEs
55    for cte in ctes {
56        // Execute the CTE query with accumulated CTE results so far
57        // This allows later CTEs to reference earlier ones
58        let rows = executor(&cte.query, &cte_results)?;
59
60        // Track memory for this CTE result before storing
61        let estimated_size = super::helpers::estimate_result_size(&rows);
62        memory_check(estimated_size)?;
63
64        //  Determine the schema for this CTE
65        let schema = derive_cte_schema(cte, &rows)?;
66
67        // Store the CTE result wrapped in Arc for efficient sharing
68        cte_results.insert(cte.name.clone(), (schema, Arc::new(rows)));
69    }
70
71    Ok(cte_results)
72}
73
74/// Derive the schema for a CTE from its query and results
75pub(super) fn derive_cte_schema(
76    cte: &vibesql_ast::CommonTableExpr,
77    rows: &[vibesql_storage::Row],
78) -> Result<vibesql_catalog::TableSchema, ExecutorError> {
79    // If column names are explicitly specified, use those
80    if let Some(column_names) = &cte.columns {
81        // Get data types from first row (if available)
82        if let Some(first_row) = rows.first() {
83            if first_row.values.len() != column_names.len() {
84                return Err(ExecutorError::UnsupportedFeature(format!(
85                    "CTE column count mismatch: specified {} columns but query returned {}",
86                    column_names.len(),
87                    first_row.values.len()
88                )));
89            }
90
91            let columns = column_names
92                .iter()
93                .zip(&first_row.values)
94                .map(|(name, value)| {
95                    let data_type = infer_type_from_value(value);
96                    vibesql_catalog::ColumnSchema::new(name.clone(), data_type, true)
97                    // nullable for
98                    // simplicity
99                })
100                .collect();
101
102            Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
103        } else {
104            // Empty result set - create schema with VARCHAR columns
105            let columns = column_names
106                .iter()
107                .map(|name| {
108                    vibesql_catalog::ColumnSchema::new(
109                        name.clone(),
110                        vibesql_types::DataType::Varchar { max_length: Some(255) },
111                        true,
112                    )
113                })
114                .collect();
115
116            Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
117        }
118    } else {
119        // No explicit column names - infer from query SELECT list
120        // Extract column names from SELECT items
121        let columns = cte
122            .query
123            .select_list
124            .iter()
125            .enumerate()
126            .map(|(i, item)| {
127                // Infer data type from first row if available, otherwise use default
128                let data_type = if let Some(first_row) = rows.first() {
129                    infer_type_from_value(&first_row.values[i])
130                } else {
131                    // No rows - use default type (VARCHAR)
132                    vibesql_types::DataType::Varchar { max_length: Some(255) }
133                };
134
135                // Extract column name from SELECT item
136                let col_name = match item {
137                    vibesql_ast::SelectItem::Wildcard { .. }
138                    | vibesql_ast::SelectItem::QualifiedWildcard { .. } => format!("col{}", i),
139                    vibesql_ast::SelectItem::Expression { expr, alias } => {
140                        if let Some(a) = alias {
141                            a.clone()
142                        } else {
143                            // Try to extract name from expression
144                            match expr {
145                                vibesql_ast::Expression::ColumnRef { table: _, column } => {
146                                    column.clone()
147                                }
148                                _ => format!("col{}", i),
149                            }
150                        }
151                    }
152                };
153
154                vibesql_catalog::ColumnSchema::new(col_name, data_type, true) // nullable
155            })
156            .collect();
157
158        Ok(vibesql_catalog::TableSchema::new(cte.name.clone(), columns))
159    }
160}
161
162/// Infer data type from a SQL value
163pub(super) fn infer_type_from_value(value: &vibesql_types::SqlValue) -> vibesql_types::DataType {
164    match value {
165        vibesql_types::SqlValue::Null => vibesql_types::DataType::Varchar { max_length: Some(255) }, // default
166        vibesql_types::SqlValue::Integer(_) => vibesql_types::DataType::Integer,
167        vibesql_types::SqlValue::Varchar(_) => {
168            vibesql_types::DataType::Varchar { max_length: Some(255) }
169        }
170        vibesql_types::SqlValue::Character(_) => vibesql_types::DataType::Character { length: 1 },
171        vibesql_types::SqlValue::Boolean(_) => vibesql_types::DataType::Boolean,
172        vibesql_types::SqlValue::Float(_) => vibesql_types::DataType::Float { precision: 53 },
173        vibesql_types::SqlValue::Double(_) => vibesql_types::DataType::DoublePrecision,
174        vibesql_types::SqlValue::Numeric(_) => {
175            vibesql_types::DataType::Numeric { precision: 10, scale: 2 }
176        }
177        vibesql_types::SqlValue::Real(_) => vibesql_types::DataType::Real,
178        vibesql_types::SqlValue::Smallint(_) => vibesql_types::DataType::Smallint,
179        vibesql_types::SqlValue::Bigint(_) => vibesql_types::DataType::Bigint,
180        vibesql_types::SqlValue::Unsigned(_) => vibesql_types::DataType::Unsigned,
181        vibesql_types::SqlValue::Date(_) => vibesql_types::DataType::Date,
182        vibesql_types::SqlValue::Time(_) => vibesql_types::DataType::Time { with_timezone: false },
183        vibesql_types::SqlValue::Timestamp(_) => {
184            vibesql_types::DataType::Timestamp { with_timezone: false }
185        }
186        vibesql_types::SqlValue::Interval(_) => {
187            // For now, return a simple INTERVAL type (can be enhanced to detect field types)
188            vibesql_types::DataType::Interval {
189                start_field: vibesql_types::IntervalField::Day,
190                end_field: None,
191            }
192        }
193        vibesql_types::SqlValue::Vector(v) => {
194            vibesql_types::DataType::Vector { dimensions: v.len() as u32 }
195        }
196    }
197}