Skip to main content

sql_cli/data/
subquery_executor.rs

1// Subquery execution handler
2// Walks the AST to evaluate subqueries and replace them with their results
3
4use crate::data::data_view::DataView;
5use crate::data::datatable::{DataTable, DataValue};
6use crate::data::query_engine::QueryEngine;
7use crate::sql::parser::ast::{
8    Condition, SelectItem, SelectStatement, SimpleWhenBranch, SqlExpression, WhenBranch,
9    WhereClause,
10};
11use anyhow::{anyhow, Result};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use tracing::{debug, info};
15
16/// Convert a DataValue back into a SqlExpression literal (used when materialising
17/// subquery results into the AST for the evaluator to handle).
18fn datavalue_to_literal(v: &DataValue) -> SqlExpression {
19    match v {
20        DataValue::Null => SqlExpression::Null,
21        DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
22        DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
23        DataValue::String(s) => SqlExpression::StringLiteral(s.clone()),
24        DataValue::InternedString(s) => SqlExpression::StringLiteral(s.to_string()),
25        DataValue::Boolean(b) => SqlExpression::BooleanLiteral(*b),
26        DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt.clone()),
27        DataValue::Vector(vec) => {
28            let components: Vec<String> = vec.iter().map(|f| f.to_string()).collect();
29            SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
30        }
31    }
32}
33
34/// Build an expression equivalent to `(e1, e2, ...) IN (v_row1, v_row2, ...)`
35/// expanded as OR of AND equality checks. For an empty subquery result the
36/// expression evaluates to FALSE (true for NOT IN).
37///
38/// Example: `(a, b) IN (SELECT x, y ...)` with rows [(1,2), (3,4)] becomes
39///     (a = 1 AND b = 2) OR (a = 3 AND b = 4)
40pub(crate) fn build_tuple_in_expression(
41    exprs: &[SqlExpression],
42    rows: &[Vec<DataValue>],
43    negate: bool,
44) -> SqlExpression {
45    if rows.is_empty() {
46        // Empty subquery: IN is always false, NOT IN is always true
47        return SqlExpression::BooleanLiteral(negate);
48    }
49
50    // Build OR of row-matches. Each row-match is AND of column equalities.
51    let mut or_expr: Option<SqlExpression> = None;
52    for row in rows {
53        // Build AND of equalities for this row
54        let mut and_expr: Option<SqlExpression> = None;
55        for (i, value) in row.iter().enumerate() {
56            let eq = SqlExpression::BinaryOp {
57                left: Box::new(exprs[i].clone()),
58                op: "=".to_string(),
59                right: Box::new(datavalue_to_literal(value)),
60            };
61            and_expr = Some(match and_expr {
62                None => eq,
63                Some(prev) => SqlExpression::BinaryOp {
64                    left: Box::new(prev),
65                    op: "AND".to_string(),
66                    right: Box::new(eq),
67                },
68            });
69        }
70        let row_match = and_expr.expect("row had zero columns — should not happen");
71        or_expr = Some(match or_expr {
72            None => row_match,
73            Some(prev) => SqlExpression::BinaryOp {
74                left: Box::new(prev),
75                op: "OR".to_string(),
76                right: Box::new(row_match),
77            },
78        });
79    }
80
81    let matches = or_expr.expect("rows was non-empty");
82    if negate {
83        SqlExpression::Not {
84            expr: Box::new(matches),
85        }
86    } else {
87        matches
88    }
89}
90
91/// Result of executing a subquery
92#[derive(Debug, Clone)]
93pub enum SubqueryResult {
94    /// Scalar subquery returned a single value
95    Scalar(DataValue),
96    /// IN subquery returned a set of values
97    ValueSet(HashSet<DataValue>),
98    /// Subquery returned multiple rows/columns (for future use)
99    Table(Arc<DataView>),
100}
101
102/// Executes subqueries within a SQL statement
103pub struct SubqueryExecutor {
104    query_engine: QueryEngine,
105    source_table: Arc<DataTable>,
106    /// Cache of executed subqueries to avoid re-execution
107    cache: HashMap<String, SubqueryResult>,
108    /// CTE context for resolving CTE references in subqueries
109    cte_context: HashMap<String, Arc<DataView>>,
110}
111
112impl SubqueryExecutor {
113    /// Create a new subquery executor
114    pub fn new(query_engine: QueryEngine, source_table: Arc<DataTable>) -> Self {
115        Self {
116            query_engine,
117            source_table,
118            cache: HashMap::new(),
119            cte_context: HashMap::new(),
120        }
121    }
122
123    /// Create a new subquery executor with CTE context
124    pub fn with_cte_context(
125        query_engine: QueryEngine,
126        source_table: Arc<DataTable>,
127        cte_context: HashMap<String, Arc<DataView>>,
128    ) -> Self {
129        Self {
130            query_engine,
131            source_table,
132            cache: HashMap::new(),
133            cte_context,
134        }
135    }
136
137    /// Execute all subqueries in a statement and return a modified statement
138    /// with subqueries replaced by their results
139    pub fn execute_subqueries(&mut self, statement: &SelectStatement) -> Result<SelectStatement> {
140        info!("SubqueryExecutor: Starting subquery execution pass");
141        info!(
142            "SubqueryExecutor: Available CTEs: {:?}",
143            self.cte_context.keys().collect::<Vec<_>>()
144        );
145
146        // Clone the statement to modify
147        let mut modified_statement = statement.clone();
148
149        // Process WHERE clause if present
150        if let Some(ref where_clause) = statement.where_clause {
151            debug!("SubqueryExecutor: Processing WHERE clause for subqueries");
152            let mut new_conditions = Vec::new();
153            for condition in &where_clause.conditions {
154                new_conditions.push(Condition {
155                    expr: self.process_expression(&condition.expr)?,
156                    connector: condition.connector.clone(),
157                });
158            }
159            modified_statement.where_clause = Some(WhereClause {
160                conditions: new_conditions,
161            });
162        }
163
164        // Process SELECT items
165        let mut new_select_items = Vec::new();
166        for item in &statement.select_items {
167            match item {
168                SelectItem::Column {
169                    column: col,
170                    leading_comments,
171                    trailing_comment,
172                } => {
173                    new_select_items.push(SelectItem::Column {
174                        column: col.clone(),
175                        leading_comments: leading_comments.clone(),
176                        trailing_comment: trailing_comment.clone(),
177                    });
178                }
179                SelectItem::Expression {
180                    expr,
181                    alias,
182                    leading_comments,
183                    trailing_comment,
184                } => {
185                    new_select_items.push(SelectItem::Expression {
186                        expr: self.process_expression(expr)?,
187                        alias: alias.clone(),
188                        leading_comments: leading_comments.clone(),
189                        trailing_comment: trailing_comment.clone(),
190                    });
191                }
192                SelectItem::Star {
193                    table_prefix,
194                    leading_comments,
195                    trailing_comment,
196                } => {
197                    new_select_items.push(SelectItem::Star {
198                        table_prefix: table_prefix.clone(),
199                        leading_comments: leading_comments.clone(),
200                        trailing_comment: trailing_comment.clone(),
201                    });
202                }
203                SelectItem::StarExclude {
204                    table_prefix,
205                    excluded_columns,
206                    leading_comments,
207                    trailing_comment,
208                } => {
209                    new_select_items.push(SelectItem::StarExclude {
210                        table_prefix: table_prefix.clone(),
211                        excluded_columns: excluded_columns.clone(),
212                        leading_comments: leading_comments.clone(),
213                        trailing_comment: trailing_comment.clone(),
214                    });
215                }
216            }
217        }
218        modified_statement.select_items = new_select_items;
219
220        // Process HAVING clause if present
221        if let Some(ref having) = statement.having {
222            debug!("SubqueryExecutor: Processing HAVING clause for subqueries");
223            modified_statement.having = Some(self.process_expression(having)?);
224        }
225
226        debug!("SubqueryExecutor: Subquery execution complete");
227        Ok(modified_statement)
228    }
229
230    /// Process an expression, executing any subqueries and replacing them with results
231    fn process_expression(&mut self, expr: &SqlExpression) -> Result<SqlExpression> {
232        match expr {
233            SqlExpression::ScalarSubquery { query } => {
234                debug!("SubqueryExecutor: Executing scalar subquery");
235                let result = self.execute_scalar_subquery(query)?;
236                Ok(result)
237            }
238
239            SqlExpression::InSubquery { expr, subquery } => {
240                debug!("SubqueryExecutor: Executing IN subquery");
241                let values = self.execute_in_subquery(subquery)?;
242
243                // Replace with InList containing the actual values
244                Ok(SqlExpression::InList {
245                    expr: Box::new(self.process_expression(expr)?),
246                    values: values
247                        .into_iter()
248                        .map(|v| match v {
249                            DataValue::Null => SqlExpression::Null,
250                            DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
251                            DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
252                            DataValue::String(s) => SqlExpression::StringLiteral(s),
253                            DataValue::InternedString(s) => {
254                                SqlExpression::StringLiteral(s.to_string())
255                            }
256                            DataValue::Boolean(b) => SqlExpression::BooleanLiteral(b),
257                            DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt),
258                            DataValue::Vector(v) => {
259                                let components: Vec<String> =
260                                    v.iter().map(|f| f.to_string()).collect();
261                                SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
262                            }
263                        })
264                        .collect(),
265                })
266            }
267
268            SqlExpression::NotInSubquery { expr, subquery } => {
269                debug!("SubqueryExecutor: Executing NOT IN subquery");
270                let values = self.execute_in_subquery(subquery)?;
271
272                // Replace with NotInList containing the actual values
273                Ok(SqlExpression::NotInList {
274                    expr: Box::new(self.process_expression(expr)?),
275                    values: values
276                        .into_iter()
277                        .map(|v| match v {
278                            DataValue::Null => SqlExpression::Null,
279                            DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
280                            DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
281                            DataValue::String(s) => SqlExpression::StringLiteral(s),
282                            DataValue::InternedString(s) => {
283                                SqlExpression::StringLiteral(s.to_string())
284                            }
285                            DataValue::Boolean(b) => SqlExpression::BooleanLiteral(b),
286                            DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt),
287                            DataValue::Vector(v) => {
288                                let components: Vec<String> =
289                                    v.iter().map(|f| f.to_string()).collect();
290                                SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
291                            }
292                        })
293                        .collect(),
294                })
295            }
296
297            SqlExpression::InSubqueryTuple { exprs, subquery } => {
298                debug!("SubqueryExecutor: Executing tuple IN subquery");
299                let processed_exprs: Vec<SqlExpression> = exprs
300                    .iter()
301                    .map(|e| self.process_expression(e))
302                    .collect::<Result<Vec<_>>>()?;
303                let rows = self.execute_tuple_subquery(subquery, exprs.len())?;
304                Ok(build_tuple_in_expression(&processed_exprs, &rows, false))
305            }
306
307            SqlExpression::NotInSubqueryTuple { exprs, subquery } => {
308                debug!("SubqueryExecutor: Executing tuple NOT IN subquery");
309                let processed_exprs: Vec<SqlExpression> = exprs
310                    .iter()
311                    .map(|e| self.process_expression(e))
312                    .collect::<Result<Vec<_>>>()?;
313                let rows = self.execute_tuple_subquery(subquery, exprs.len())?;
314                Ok(build_tuple_in_expression(&processed_exprs, &rows, true))
315            }
316
317            // Process nested expressions
318            SqlExpression::BinaryOp { left, op, right } => Ok(SqlExpression::BinaryOp {
319                left: Box::new(self.process_expression(left)?),
320                op: op.clone(),
321                right: Box::new(self.process_expression(right)?),
322            }),
323
324            // Note: UnaryOp doesn't exist in the current AST, handle negation differently
325            // This case might need to be removed or adapted based on actual AST structure
326            SqlExpression::Between { expr, lower, upper } => Ok(SqlExpression::Between {
327                expr: Box::new(self.process_expression(expr)?),
328                lower: Box::new(self.process_expression(lower)?),
329                upper: Box::new(self.process_expression(upper)?),
330            }),
331
332            SqlExpression::InList { expr, values } => Ok(SqlExpression::InList {
333                expr: Box::new(self.process_expression(expr)?),
334                values: values
335                    .iter()
336                    .map(|v| self.process_expression(v))
337                    .collect::<Result<Vec<_>>>()?,
338            }),
339
340            SqlExpression::NotInList { expr, values } => Ok(SqlExpression::NotInList {
341                expr: Box::new(self.process_expression(expr)?),
342                values: values
343                    .iter()
344                    .map(|v| self.process_expression(v))
345                    .collect::<Result<Vec<_>>>()?,
346            }),
347
348            SqlExpression::FunctionCall {
349                name,
350                args,
351                distinct,
352            } => Ok(SqlExpression::FunctionCall {
353                name: name.clone(),
354                args: args
355                    .iter()
356                    .map(|a| self.process_expression(a))
357                    .collect::<Result<Vec<_>>>()?,
358                distinct: *distinct,
359            }),
360
361            // Searched CASE: subqueries can appear in any WHEN condition, any
362            // result, or the ELSE branch (e.g. THEN (SELECT MAX(x) FROM ...)).
363            // Recurse into all of them so they are pre-executed like in WHERE.
364            SqlExpression::CaseExpression {
365                when_branches,
366                else_branch,
367            } => Ok(SqlExpression::CaseExpression {
368                when_branches: when_branches
369                    .iter()
370                    .map(|b| {
371                        Ok(WhenBranch {
372                            condition: Box::new(self.process_expression(&b.condition)?),
373                            result: Box::new(self.process_expression(&b.result)?),
374                        })
375                    })
376                    .collect::<Result<Vec<_>>>()?,
377                else_branch: match else_branch {
378                    Some(e) => Some(Box::new(self.process_expression(e)?)),
379                    None => None,
380                },
381            }),
382
383            // Simple CASE: CASE expr WHEN value THEN result ...
384            SqlExpression::SimpleCaseExpression {
385                expr,
386                when_branches,
387                else_branch,
388            } => Ok(SqlExpression::SimpleCaseExpression {
389                expr: Box::new(self.process_expression(expr)?),
390                when_branches: when_branches
391                    .iter()
392                    .map(|b| {
393                        Ok(SimpleWhenBranch {
394                            value: Box::new(self.process_expression(&b.value)?),
395                            result: Box::new(self.process_expression(&b.result)?),
396                        })
397                    })
398                    .collect::<Result<Vec<_>>>()?,
399                else_branch: match else_branch {
400                    Some(e) => Some(Box::new(self.process_expression(e)?)),
401                    None => None,
402                },
403            }),
404
405            SqlExpression::Not { expr } => Ok(SqlExpression::Not {
406                expr: Box::new(self.process_expression(expr)?),
407            }),
408
409            // Pass through expressions that don't contain subqueries
410            _ => Ok(expr.clone()),
411        }
412    }
413
414    /// Execute a scalar subquery and return a single value
415    fn execute_scalar_subquery(&mut self, query: &SelectStatement) -> Result<SqlExpression> {
416        let cache_key = format!("scalar:{:?}", query);
417
418        // Check cache first
419        if let Some(cached) = self.cache.get(&cache_key) {
420            debug!("SubqueryExecutor: Using cached scalar subquery result");
421            if let SubqueryResult::Scalar(value) = cached {
422                return Ok(self.datavalue_to_expression(value.clone()));
423            }
424        }
425
426        info!("SubqueryExecutor: Executing scalar subquery");
427
428        // Execute the subquery using execute_statement_with_cte_context
429        let result_view = self.query_engine.execute_statement_with_cte_context(
430            self.source_table.clone(),
431            query.clone(),
432            &self.cte_context,
433        )?;
434
435        // Scalar subquery must return exactly one row and one column
436        if result_view.row_count() != 1 {
437            return Err(anyhow!(
438                "Scalar subquery returned {} rows, expected exactly 1",
439                result_view.row_count()
440            ));
441        }
442
443        if result_view.column_count() != 1 {
444            return Err(anyhow!(
445                "Scalar subquery returned {} columns, expected exactly 1",
446                result_view.column_count()
447            ));
448        }
449
450        // Get the single value
451        let value = if let Some(row) = result_view.get_row(0) {
452            row.values.get(0).cloned().unwrap_or(DataValue::Null)
453        } else {
454            DataValue::Null
455        };
456
457        // Cache the result
458        self.cache
459            .insert(cache_key, SubqueryResult::Scalar(value.clone()));
460
461        Ok(self.datavalue_to_expression(value))
462    }
463
464    /// Execute an IN subquery and return a set of values
465    fn execute_in_subquery(&mut self, query: &SelectStatement) -> Result<Vec<DataValue>> {
466        let cache_key = format!("in:{:?}", query);
467
468        // Check cache first
469        if let Some(cached) = self.cache.get(&cache_key) {
470            debug!("SubqueryExecutor: Using cached IN subquery result");
471            if let SubqueryResult::ValueSet(values) = cached {
472                return Ok(values.iter().cloned().collect());
473            }
474        }
475
476        info!("SubqueryExecutor: Executing IN subquery");
477        debug!(
478            "SubqueryExecutor: Available CTEs in context: {:?}",
479            self.cte_context.keys().collect::<Vec<_>>()
480        );
481        debug!("SubqueryExecutor: Subquery: {:?}", query);
482
483        // Execute the subquery using execute_statement_with_cte_context
484        let result_view = self.query_engine.execute_statement_with_cte_context(
485            self.source_table.clone(),
486            query.clone(),
487            &self.cte_context,
488        )?;
489
490        debug!(
491            "SubqueryExecutor: IN subquery returned {} rows",
492            result_view.row_count()
493        );
494
495        // IN subquery must return exactly one column
496        if result_view.column_count() != 1 {
497            return Err(anyhow!(
498                "IN subquery returned {} columns, expected exactly 1",
499                result_view.column_count()
500            ));
501        }
502
503        // Collect all values from the first column
504        let mut values = HashSet::new();
505        for row_idx in 0..result_view.row_count() {
506            if let Some(row) = result_view.get_row(row_idx) {
507                if let Some(value) = row.values.get(0) {
508                    values.insert(value.clone());
509                }
510            }
511        }
512
513        // Cache the result
514        self.cache
515            .insert(cache_key, SubqueryResult::ValueSet(values.clone()));
516
517        Ok(values.into_iter().collect())
518    }
519
520    /// Execute a multi-column subquery for tuple IN, returning rows of values.
521    /// Validates that the number of columns matches the LHS tuple size.
522    fn execute_tuple_subquery(
523        &mut self,
524        query: &SelectStatement,
525        expected_cols: usize,
526    ) -> Result<Vec<Vec<DataValue>>> {
527        info!(
528            "SubqueryExecutor: Executing tuple IN subquery (expecting {} columns)",
529            expected_cols
530        );
531
532        let result_view = self.query_engine.execute_statement_with_cte_context(
533            self.source_table.clone(),
534            query.clone(),
535            &self.cte_context,
536        )?;
537
538        if result_view.column_count() != expected_cols {
539            return Err(anyhow!(
540                "Tuple IN subquery returned {} columns, expected {}",
541                result_view.column_count(),
542                expected_cols
543            ));
544        }
545
546        let mut rows = Vec::with_capacity(result_view.row_count());
547        for row_idx in 0..result_view.row_count() {
548            if let Some(row) = result_view.get_row(row_idx) {
549                rows.push(row.values.clone());
550            }
551        }
552
553        debug!(
554            "SubqueryExecutor: tuple IN subquery returned {} rows",
555            rows.len()
556        );
557        Ok(rows)
558    }
559
560    /// Convert a DataValue to a SqlExpression
561    fn datavalue_to_expression(&self, value: DataValue) -> SqlExpression {
562        match value {
563            DataValue::Null => SqlExpression::Null,
564            DataValue::Integer(i) => SqlExpression::NumberLiteral(i.to_string()),
565            DataValue::Float(f) => SqlExpression::NumberLiteral(f.to_string()),
566            DataValue::String(s) => SqlExpression::StringLiteral(s),
567            DataValue::InternedString(s) => SqlExpression::StringLiteral(s.to_string()),
568            DataValue::Boolean(b) => SqlExpression::BooleanLiteral(b),
569            DataValue::DateTime(dt) => SqlExpression::StringLiteral(dt),
570            DataValue::Vector(v) => {
571                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
572                SqlExpression::StringLiteral(format!("[{}]", components.join(",")))
573            }
574        }
575    }
576}