Skip to main content

sql_cli/data/
subquery_executor.rs

// Subquery execution handler.
// Walks the AST to evaluate subqueries and replace them with their results.
3
4use crate::data::data_view::DataView;
5use crate::data::datatable::{DataTable, DataValue};
6use crate::data::query_engine::QueryEngine;
7use crate::sql::parser::ast::{Condition, SelectItem, SelectStatement, SqlExpression, WhereClause};
8use anyhow::{anyhow, Result};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tracing::{debug, info};
12
13/// Convert a DataValue back into a SqlExpression literal (used when materialising
14/// subquery results into the AST for the evaluator to handle).
15fn datavalue_to_literal(v: &DataValue) -> SqlExpression {
16    match v {
17        DataValue::Null => SqlExpression::Null,
18        DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
19        DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
20        DataValue::String(s) => SqlExpression::StringLiteral(s.clone()),
21        DataValue::InternedString(s) => SqlExpression::StringLiteral(s.to_string()),
22        DataValue::Boolean(b) => SqlExpression::BooleanLiteral(*b),
23        DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt.clone()),
24        DataValue::Vector(vec) => {
25            let components: Vec<String> = vec.iter().map(|f| f.to_string()).collect();
26            SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
27        }
28    }
29}
30
31/// Build an expression equivalent to `(e1, e2, ...) IN (v_row1, v_row2, ...)`
32/// expanded as OR of AND equality checks. For an empty subquery result the
33/// expression evaluates to FALSE (true for NOT IN).
34///
35/// Example: `(a, b) IN (SELECT x, y ...)` with rows [(1,2), (3,4)] becomes
36///     (a = 1 AND b = 2) OR (a = 3 AND b = 4)
37pub(crate) fn build_tuple_in_expression(
38    exprs: &[SqlExpression],
39    rows: &[Vec<DataValue>],
40    negate: bool,
41) -> SqlExpression {
42    if rows.is_empty() {
43        // Empty subquery: IN is always false, NOT IN is always true
44        return SqlExpression::BooleanLiteral(negate);
45    }
46
47    // Build OR of row-matches. Each row-match is AND of column equalities.
48    let mut or_expr: Option<SqlExpression> = None;
49    for row in rows {
50        // Build AND of equalities for this row
51        let mut and_expr: Option<SqlExpression> = None;
52        for (i, value) in row.iter().enumerate() {
53            let eq = SqlExpression::BinaryOp {
54                left: Box::new(exprs[i].clone()),
55                op: "=".to_string(),
56                right: Box::new(datavalue_to_literal(value)),
57            };
58            and_expr = Some(match and_expr {
59                None => eq,
60                Some(prev) => SqlExpression::BinaryOp {
61                    left: Box::new(prev),
62                    op: "AND".to_string(),
63                    right: Box::new(eq),
64                },
65            });
66        }
67        let row_match = and_expr.expect("row had zero columns — should not happen");
68        or_expr = Some(match or_expr {
69            None => row_match,
70            Some(prev) => SqlExpression::BinaryOp {
71                left: Box::new(prev),
72                op: "OR".to_string(),
73                right: Box::new(row_match),
74            },
75        });
76    }
77
78    let matches = or_expr.expect("rows was non-empty");
79    if negate {
80        SqlExpression::Not {
81            expr: Box::new(matches),
82        }
83    } else {
84        matches
85    }
86}
87
/// Result of executing a subquery.
///
/// Values of this type are what `SubqueryExecutor` stores in its result
/// cache: the scalar path inserts `Scalar`, the single-column IN path inserts
/// `ValueSet`.
#[derive(Debug, Clone)]
pub enum SubqueryResult {
    /// Scalar subquery returned a single value
    Scalar(DataValue),
    /// IN subquery returned a set of values (held in a `HashSet`, so
    /// duplicate values collapse and iteration order is unspecified)
    ValueSet(HashSet<DataValue>),
    /// Subquery returned multiple rows/columns (for future use)
    Table(Arc<DataView>),
}
98
/// Executes subqueries within a SQL statement.
///
/// Each subquery is run through `query_engine` against `source_table` with
/// `cte_context` supplied, and the subquery node in the AST is replaced by
/// literal expressions built from the result.
pub struct SubqueryExecutor {
    /// Engine used to run each subquery statement
    query_engine: QueryEngine,
    /// Table handed to the engine as the data source for every subquery
    source_table: Arc<DataTable>,
    /// Cache of executed subqueries to avoid re-execution; keys are a kind
    /// prefix (`scalar:` / `in:`) followed by the `Debug` text of the query
    cache: HashMap<String, SubqueryResult>,
    /// CTE context for resolving CTE references in subqueries
    cte_context: HashMap<String, Arc<DataView>>,
}
108
109impl SubqueryExecutor {
110    /// Create a new subquery executor
111    pub fn new(query_engine: QueryEngine, source_table: Arc<DataTable>) -> Self {
112        Self {
113            query_engine,
114            source_table,
115            cache: HashMap::new(),
116            cte_context: HashMap::new(),
117        }
118    }
119
120    /// Create a new subquery executor with CTE context
121    pub fn with_cte_context(
122        query_engine: QueryEngine,
123        source_table: Arc<DataTable>,
124        cte_context: HashMap<String, Arc<DataView>>,
125    ) -> Self {
126        Self {
127            query_engine,
128            source_table,
129            cache: HashMap::new(),
130            cte_context,
131        }
132    }
133
134    /// Execute all subqueries in a statement and return a modified statement
135    /// with subqueries replaced by their results
136    pub fn execute_subqueries(&mut self, statement: &SelectStatement) -> Result<SelectStatement> {
137        info!("SubqueryExecutor: Starting subquery execution pass");
138        info!(
139            "SubqueryExecutor: Available CTEs: {:?}",
140            self.cte_context.keys().collect::<Vec<_>>()
141        );
142
143        // Clone the statement to modify
144        let mut modified_statement = statement.clone();
145
146        // Process WHERE clause if present
147        if let Some(ref where_clause) = statement.where_clause {
148            debug!("SubqueryExecutor: Processing WHERE clause for subqueries");
149            let mut new_conditions = Vec::new();
150            for condition in &where_clause.conditions {
151                new_conditions.push(Condition {
152                    expr: self.process_expression(&condition.expr)?,
153                    connector: condition.connector.clone(),
154                });
155            }
156            modified_statement.where_clause = Some(WhereClause {
157                conditions: new_conditions,
158            });
159        }
160
161        // Process SELECT items
162        let mut new_select_items = Vec::new();
163        for item in &statement.select_items {
164            match item {
165                SelectItem::Column {
166                    column: col,
167                    leading_comments,
168                    trailing_comment,
169                } => {
170                    new_select_items.push(SelectItem::Column {
171                        column: col.clone(),
172                        leading_comments: leading_comments.clone(),
173                        trailing_comment: trailing_comment.clone(),
174                    });
175                }
176                SelectItem::Expression {
177                    expr,
178                    alias,
179                    leading_comments,
180                    trailing_comment,
181                } => {
182                    new_select_items.push(SelectItem::Expression {
183                        expr: self.process_expression(expr)?,
184                        alias: alias.clone(),
185                        leading_comments: leading_comments.clone(),
186                        trailing_comment: trailing_comment.clone(),
187                    });
188                }
189                SelectItem::Star {
190                    table_prefix,
191                    leading_comments,
192                    trailing_comment,
193                } => {
194                    new_select_items.push(SelectItem::Star {
195                        table_prefix: table_prefix.clone(),
196                        leading_comments: leading_comments.clone(),
197                        trailing_comment: trailing_comment.clone(),
198                    });
199                }
200                SelectItem::StarExclude {
201                    table_prefix,
202                    excluded_columns,
203                    leading_comments,
204                    trailing_comment,
205                } => {
206                    new_select_items.push(SelectItem::StarExclude {
207                        table_prefix: table_prefix.clone(),
208                        excluded_columns: excluded_columns.clone(),
209                        leading_comments: leading_comments.clone(),
210                        trailing_comment: trailing_comment.clone(),
211                    });
212                }
213            }
214        }
215        modified_statement.select_items = new_select_items;
216
217        // Process HAVING clause if present
218        if let Some(ref having) = statement.having {
219            debug!("SubqueryExecutor: Processing HAVING clause for subqueries");
220            modified_statement.having = Some(self.process_expression(having)?);
221        }
222
223        debug!("SubqueryExecutor: Subquery execution complete");
224        Ok(modified_statement)
225    }
226
227    /// Process an expression, executing any subqueries and replacing them with results
228    fn process_expression(&mut self, expr: &SqlExpression) -> Result<SqlExpression> {
229        match expr {
230            SqlExpression::ScalarSubquery { query } => {
231                debug!("SubqueryExecutor: Executing scalar subquery");
232                let result = self.execute_scalar_subquery(query)?;
233                Ok(result)
234            }
235
236            SqlExpression::InSubquery { expr, subquery } => {
237                debug!("SubqueryExecutor: Executing IN subquery");
238                let values = self.execute_in_subquery(subquery)?;
239
240                // Replace with InList containing the actual values
241                Ok(SqlExpression::InList {
242                    expr: Box::new(self.process_expression(expr)?),
243                    values: values
244                        .into_iter()
245                        .map(|v| match v {
246                            DataValue::Null => SqlExpression::Null,
247                            DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
248                            DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
249                            DataValue::String(s) => SqlExpression::StringLiteral(s),
250                            DataValue::InternedString(s) => {
251                                SqlExpression::StringLiteral(s.to_string())
252                            }
253                            DataValue::Boolean(b) => SqlExpression::BooleanLiteral(b),
254                            DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt),
255                            DataValue::Vector(v) => {
256                                let components: Vec<String> =
257                                    v.iter().map(|f| f.to_string()).collect();
258                                SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
259                            }
260                        })
261                        .collect(),
262                })
263            }
264
265            SqlExpression::NotInSubquery { expr, subquery } => {
266                debug!("SubqueryExecutor: Executing NOT IN subquery");
267                let values = self.execute_in_subquery(subquery)?;
268
269                // Replace with NotInList containing the actual values
270                Ok(SqlExpression::NotInList {
271                    expr: Box::new(self.process_expression(expr)?),
272                    values: values
273                        .into_iter()
274                        .map(|v| match v {
275                            DataValue::Null => SqlExpression::Null,
276                            DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
277                            DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
278                            DataValue::String(s) => SqlExpression::StringLiteral(s),
279                            DataValue::InternedString(s) => {
280                                SqlExpression::StringLiteral(s.to_string())
281                            }
282                            DataValue::Boolean(b) => SqlExpression::BooleanLiteral(b),
283                            DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt),
284                            DataValue::Vector(v) => {
285                                let components: Vec<String> =
286                                    v.iter().map(|f| f.to_string()).collect();
287                                SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
288                            }
289                        })
290                        .collect(),
291                })
292            }
293
294            SqlExpression::InSubqueryTuple { exprs, subquery } => {
295                debug!("SubqueryExecutor: Executing tuple IN subquery");
296                let processed_exprs: Vec<SqlExpression> = exprs
297                    .iter()
298                    .map(|e| self.process_expression(e))
299                    .collect::<Result<Vec<_>>>()?;
300                let rows = self.execute_tuple_subquery(subquery, exprs.len())?;
301                Ok(build_tuple_in_expression(&processed_exprs, &rows, false))
302            }
303
304            SqlExpression::NotInSubqueryTuple { exprs, subquery } => {
305                debug!("SubqueryExecutor: Executing tuple NOT IN subquery");
306                let processed_exprs: Vec<SqlExpression> = exprs
307                    .iter()
308                    .map(|e| self.process_expression(e))
309                    .collect::<Result<Vec<_>>>()?;
310                let rows = self.execute_tuple_subquery(subquery, exprs.len())?;
311                Ok(build_tuple_in_expression(&processed_exprs, &rows, true))
312            }
313
314            // Process nested expressions
315            SqlExpression::BinaryOp { left, op, right } => Ok(SqlExpression::BinaryOp {
316                left: Box::new(self.process_expression(left)?),
317                op: op.clone(),
318                right: Box::new(self.process_expression(right)?),
319            }),
320
321            // Note: UnaryOp doesn't exist in the current AST, handle negation differently
322            // This case might need to be removed or adapted based on actual AST structure
323            SqlExpression::Between { expr, lower, upper } => Ok(SqlExpression::Between {
324                expr: Box::new(self.process_expression(expr)?),
325                lower: Box::new(self.process_expression(lower)?),
326                upper: Box::new(self.process_expression(upper)?),
327            }),
328
329            SqlExpression::InList { expr, values } => Ok(SqlExpression::InList {
330                expr: Box::new(self.process_expression(expr)?),
331                values: values
332                    .iter()
333                    .map(|v| self.process_expression(v))
334                    .collect::<Result<Vec<_>>>()?,
335            }),
336
337            SqlExpression::NotInList { expr, values } => Ok(SqlExpression::NotInList {
338                expr: Box::new(self.process_expression(expr)?),
339                values: values
340                    .iter()
341                    .map(|v| self.process_expression(v))
342                    .collect::<Result<Vec<_>>>()?,
343            }),
344
345            // CaseWhen doesn't exist in current AST, skip for now
346            SqlExpression::FunctionCall {
347                name,
348                args,
349                distinct,
350            } => Ok(SqlExpression::FunctionCall {
351                name: name.clone(),
352                args: args
353                    .iter()
354                    .map(|a| self.process_expression(a))
355                    .collect::<Result<Vec<_>>>()?,
356                distinct: *distinct,
357            }),
358
359            // Pass through expressions that don't contain subqueries
360            _ => Ok(expr.clone()),
361        }
362    }
363
364    /// Execute a scalar subquery and return a single value
365    fn execute_scalar_subquery(&mut self, query: &SelectStatement) -> Result<SqlExpression> {
366        let cache_key = format!("scalar:{:?}", query);
367
368        // Check cache first
369        if let Some(cached) = self.cache.get(&cache_key) {
370            debug!("SubqueryExecutor: Using cached scalar subquery result");
371            if let SubqueryResult::Scalar(value) = cached {
372                return Ok(self.datavalue_to_expression(value.clone()));
373            }
374        }
375
376        info!("SubqueryExecutor: Executing scalar subquery");
377
378        // Execute the subquery using execute_statement_with_cte_context
379        let result_view = self.query_engine.execute_statement_with_cte_context(
380            self.source_table.clone(),
381            query.clone(),
382            &self.cte_context,
383        )?;
384
385        // Scalar subquery must return exactly one row and one column
386        if result_view.row_count() != 1 {
387            return Err(anyhow!(
388                "Scalar subquery returned {} rows, expected exactly 1",
389                result_view.row_count()
390            ));
391        }
392
393        if result_view.column_count() != 1 {
394            return Err(anyhow!(
395                "Scalar subquery returned {} columns, expected exactly 1",
396                result_view.column_count()
397            ));
398        }
399
400        // Get the single value
401        let value = if let Some(row) = result_view.get_row(0) {
402            row.values.get(0).cloned().unwrap_or(DataValue::Null)
403        } else {
404            DataValue::Null
405        };
406
407        // Cache the result
408        self.cache
409            .insert(cache_key, SubqueryResult::Scalar(value.clone()));
410
411        Ok(self.datavalue_to_expression(value))
412    }
413
414    /// Execute an IN subquery and return a set of values
415    fn execute_in_subquery(&mut self, query: &SelectStatement) -> Result<Vec<DataValue>> {
416        let cache_key = format!("in:{:?}", query);
417
418        // Check cache first
419        if let Some(cached) = self.cache.get(&cache_key) {
420            debug!("SubqueryExecutor: Using cached IN subquery result");
421            if let SubqueryResult::ValueSet(values) = cached {
422                return Ok(values.iter().cloned().collect());
423            }
424        }
425
426        info!("SubqueryExecutor: Executing IN subquery");
427        debug!(
428            "SubqueryExecutor: Available CTEs in context: {:?}",
429            self.cte_context.keys().collect::<Vec<_>>()
430        );
431        debug!("SubqueryExecutor: Subquery: {:?}", query);
432
433        // Execute the subquery using execute_statement_with_cte_context
434        let result_view = self.query_engine.execute_statement_with_cte_context(
435            self.source_table.clone(),
436            query.clone(),
437            &self.cte_context,
438        )?;
439
440        debug!(
441            "SubqueryExecutor: IN subquery returned {} rows",
442            result_view.row_count()
443        );
444
445        // IN subquery must return exactly one column
446        if result_view.column_count() != 1 {
447            return Err(anyhow!(
448                "IN subquery returned {} columns, expected exactly 1",
449                result_view.column_count()
450            ));
451        }
452
453        // Collect all values from the first column
454        let mut values = HashSet::new();
455        for row_idx in 0..result_view.row_count() {
456            if let Some(row) = result_view.get_row(row_idx) {
457                if let Some(value) = row.values.get(0) {
458                    values.insert(value.clone());
459                }
460            }
461        }
462
463        // Cache the result
464        self.cache
465            .insert(cache_key, SubqueryResult::ValueSet(values.clone()));
466
467        Ok(values.into_iter().collect())
468    }
469
470    /// Execute a multi-column subquery for tuple IN, returning rows of values.
471    /// Validates that the number of columns matches the LHS tuple size.
472    fn execute_tuple_subquery(
473        &mut self,
474        query: &SelectStatement,
475        expected_cols: usize,
476    ) -> Result<Vec<Vec<DataValue>>> {
477        info!(
478            "SubqueryExecutor: Executing tuple IN subquery (expecting {} columns)",
479            expected_cols
480        );
481
482        let result_view = self.query_engine.execute_statement_with_cte_context(
483            self.source_table.clone(),
484            query.clone(),
485            &self.cte_context,
486        )?;
487
488        if result_view.column_count() != expected_cols {
489            return Err(anyhow!(
490                "Tuple IN subquery returned {} columns, expected {}",
491                result_view.column_count(),
492                expected_cols
493            ));
494        }
495
496        let mut rows = Vec::with_capacity(result_view.row_count());
497        for row_idx in 0..result_view.row_count() {
498            if let Some(row) = result_view.get_row(row_idx) {
499                rows.push(row.values.clone());
500            }
501        }
502
503        debug!(
504            "SubqueryExecutor: tuple IN subquery returned {} rows",
505            rows.len()
506        );
507        Ok(rows)
508    }
509
510    /// Convert a DataValue to a SqlExpression
511    fn datavalue_to_expression(&self, value: DataValue) -> SqlExpression {
512        match value {
513            DataValue::Null => SqlExpression::Null,
514            DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
515            DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
516            DataValue::String(s) => SqlExpression::StringLiteral(s),
517            DataValue::InternedString(s) => SqlExpression::StringLiteral(s.to_string()),
518            DataValue::Boolean(b) => SqlExpression::BooleanLiteral(b),
519            DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt),
520            DataValue::Vector(v) => {
521                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
522                SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
523            }
524        }
525    }
526}