sql_cli/query_plan/
correlated_subquery_analyzer.rs

1//! Correlated Subquery Analysis
2//!
3//! This module analyzes SELECT statements to detect and classify subqueries,
4//! particularly correlated subqueries that reference columns from outer queries.
5//!
6//! # Purpose
7//!
8//! Before we can transform correlated subqueries, we need to understand:
9//! - Where subqueries appear in the query
10//! - Which subqueries are correlated (reference outer query)
11//! - What type of correlation exists (scalar, EXISTS, IN, etc.)
12//! - Which columns are referenced from outer scope
13//!
14//! This analyzer provides visibility into these patterns, which will inform
15//! future transformation strategies.
16
17use crate::sql::parser::ast::{SelectStatement, SqlExpression, WhereClause};
18use std::collections::HashSet;
19
/// Clause of the enclosing query in which a subquery was found.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so that a location
/// can be used as a `HashMap`/`HashSet` key (for example to group subqueries
/// by clause).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubqueryLocation {
    /// Subquery used as the data source in the FROM clause (derived table)
    FromClause,
    /// Subquery inside a WHERE condition
    WhereClause,
    /// Subquery inside a SELECT list expression (scalar subquery)
    SelectList,
    /// Subquery inside the HAVING expression
    HavingClause,
    /// Subquery inside a JOIN ON condition
    // NOTE(review): nothing in this file constructs this variant yet.
    JoinCondition,
}
34
/// Classification of a subquery by the way the enclosing query uses it.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so that a type can
/// be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubqueryType {
    /// Scalar subquery that returns a single value
    Scalar,
    /// `IN (subquery)`; `negated` is `true` for `NOT IN (subquery)`
    InList { negated: bool },
    /// `EXISTS (subquery)`; `negated` is `true` for `NOT EXISTS`
    /// (not yet parsed, reserved for future use)
    Exists { negated: bool },
    /// Subquery used as a derived table in the FROM clause
    DerivedTable,
}
47
48/// Information about a detected subquery
49#[derive(Debug, Clone)]
50pub struct SubqueryInfo {
51    /// Location where subquery appears
52    pub location: SubqueryLocation,
53    /// Type of subquery
54    pub subquery_type: SubqueryType,
55    /// Whether this subquery references outer query columns
56    pub is_correlated: bool,
57    /// Columns from outer query that are referenced
58    pub outer_references: Vec<String>,
59    /// The subquery statement itself
60    pub statement: SelectStatement,
61}
62
63/// Analysis results for a query
64#[derive(Debug, Default)]
65pub struct CorrelationAnalysis {
66    /// All subqueries found in the query
67    pub subqueries: Vec<SubqueryInfo>,
68    /// Total count of subqueries
69    pub total_count: usize,
70    /// Count of correlated subqueries
71    pub correlated_count: usize,
72    /// Count of non-correlated subqueries
73    pub non_correlated_count: usize,
74}
75
76impl CorrelationAnalysis {
77    /// Generate a human-readable report
78    pub fn report(&self) -> String {
79        let mut report = String::new();
80
81        report.push_str(&format!("=== Subquery Analysis ===\n"));
82        report.push_str(&format!("Total subqueries: {}\n", self.total_count));
83        report.push_str(&format!("  Correlated: {}\n", self.correlated_count));
84        report.push_str(&format!(
85            "  Non-correlated: {}\n",
86            self.non_correlated_count
87        ));
88        report.push_str("\n");
89
90        if self.subqueries.is_empty() {
91            report.push_str("No subqueries detected.\n");
92            return report;
93        }
94
95        for (idx, info) in self.subqueries.iter().enumerate() {
96            report.push_str(&format!("Subquery #{}: ", idx + 1));
97
98            // Location
99            report.push_str(&format!("{:?} - ", info.location));
100
101            // Type
102            report.push_str(&format!("{:?}", info.subquery_type));
103
104            // Correlation status
105            if info.is_correlated {
106                report.push_str(" [CORRELATED]\n");
107                report.push_str(&format!(
108                    "  Outer references: {:?}\n",
109                    info.outer_references
110                ));
111            } else {
112                report.push_str(" [NON-CORRELATED]\n");
113            }
114        }
115
116        report
117    }
118}
119
/// Analyzer for detecting and classifying correlated subqueries
///
/// Derives `Debug` and `Clone` so the analyzer can be logged and embedded in
/// types that derive those traits themselves.
#[derive(Debug, Clone)]
pub struct CorrelatedSubqueryAnalyzer {
    /// Stack of table/alias names available at each nesting level; the last
    /// element belongs to the query currently being analyzed.
    /// Used to determine if a column reference is to an outer query
    scope_stack: Vec<HashSet<String>>,
}
126
127impl CorrelatedSubqueryAnalyzer {
128    pub fn new() -> Self {
129        Self {
130            scope_stack: vec![HashSet::new()],
131        }
132    }
133
134    /// Analyze a SELECT statement for subqueries
135    pub fn analyze(&mut self, stmt: &SelectStatement) -> CorrelationAnalysis {
136        let mut analysis = CorrelationAnalysis::default();
137
138        // Collect table/alias names in current scope
139        let mut current_scope = HashSet::new();
140        if let Some(ref table) = stmt.from_table {
141            current_scope.insert(table.clone());
142        }
143        if let Some(ref alias) = stmt.from_alias {
144            current_scope.insert(alias.clone());
145        }
146
147        // Push current scope onto stack
148        self.scope_stack.push(current_scope);
149
150        // Analyze different parts of the query
151        self.analyze_from_clause(stmt, &mut analysis);
152        self.analyze_select_list(stmt, &mut analysis);
153        self.analyze_where_clause(stmt, &mut analysis);
154        self.analyze_having_clause(stmt, &mut analysis);
155
156        // Pop scope
157        self.scope_stack.pop();
158
159        // Update counts
160        analysis.total_count = analysis.subqueries.len();
161        analysis.correlated_count = analysis
162            .subqueries
163            .iter()
164            .filter(|s| s.is_correlated)
165            .count();
166        analysis.non_correlated_count = analysis.total_count - analysis.correlated_count;
167
168        analysis
169    }
170
171    /// Analyze FROM clause for subqueries
172    fn analyze_from_clause(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
173        if let Some(ref subquery) = stmt.from_subquery {
174            let outer_refs = self.find_outer_references(subquery);
175
176            analysis.subqueries.push(SubqueryInfo {
177                location: SubqueryLocation::FromClause,
178                subquery_type: SubqueryType::DerivedTable,
179                is_correlated: !outer_refs.is_empty(),
180                outer_references: outer_refs,
181                statement: (**subquery).clone(),
182            });
183        }
184    }
185
186    /// Analyze SELECT list for scalar subqueries
187    fn analyze_select_list(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
188        for item in &stmt.select_items {
189            if let crate::sql::parser::ast::SelectItem::Expression { expr, .. } = item {
190                self.analyze_expression_for_subqueries(
191                    expr,
192                    SubqueryLocation::SelectList,
193                    analysis,
194                );
195            }
196        }
197    }
198
199    /// Analyze WHERE clause for subqueries
200    fn analyze_where_clause(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
201        if let Some(ref where_clause) = stmt.where_clause {
202            for condition in &where_clause.conditions {
203                self.analyze_expression_for_subqueries(
204                    &condition.expr,
205                    SubqueryLocation::WhereClause,
206                    analysis,
207                );
208            }
209        }
210    }
211
212    /// Analyze HAVING clause for subqueries
213    fn analyze_having_clause(
214        &mut self,
215        stmt: &SelectStatement,
216        analysis: &mut CorrelationAnalysis,
217    ) {
218        if let Some(ref having_expr) = stmt.having {
219            self.analyze_expression_for_subqueries(
220                having_expr,
221                SubqueryLocation::HavingClause,
222                analysis,
223            );
224        }
225    }
226
227    /// Recursively analyze an expression for subqueries
228    fn analyze_expression_for_subqueries(
229        &mut self,
230        expr: &SqlExpression,
231        location: SubqueryLocation,
232        analysis: &mut CorrelationAnalysis,
233    ) {
234        match expr {
235            SqlExpression::ScalarSubquery { query } => {
236                let outer_refs = self.find_outer_references(query);
237
238                analysis.subqueries.push(SubqueryInfo {
239                    location: location.clone(),
240                    subquery_type: SubqueryType::Scalar,
241                    is_correlated: !outer_refs.is_empty(),
242                    outer_references: outer_refs,
243                    statement: (**query).clone(),
244                });
245            }
246
247            SqlExpression::InSubquery { expr: _, subquery } => {
248                let outer_refs = self.find_outer_references(subquery);
249
250                analysis.subqueries.push(SubqueryInfo {
251                    location: location.clone(),
252                    subquery_type: SubqueryType::InList { negated: false },
253                    is_correlated: !outer_refs.is_empty(),
254                    outer_references: outer_refs,
255                    statement: (**subquery).clone(),
256                });
257            }
258
259            SqlExpression::NotInSubquery { expr: _, subquery } => {
260                let outer_refs = self.find_outer_references(subquery);
261
262                analysis.subqueries.push(SubqueryInfo {
263                    location: location.clone(),
264                    subquery_type: SubqueryType::InList { negated: true },
265                    is_correlated: !outer_refs.is_empty(),
266                    outer_references: outer_refs,
267                    statement: (**subquery).clone(),
268                });
269            }
270
271            // Recursively check nested expressions
272            SqlExpression::BinaryOp { left, right, .. } => {
273                self.analyze_expression_for_subqueries(left, location.clone(), analysis);
274                self.analyze_expression_for_subqueries(right, location, analysis);
275            }
276
277            SqlExpression::Not { expr } => {
278                self.analyze_expression_for_subqueries(expr, location, analysis);
279            }
280
281            SqlExpression::FunctionCall { args, .. } => {
282                for arg in args {
283                    self.analyze_expression_for_subqueries(arg, location.clone(), analysis);
284                }
285            }
286
287            _ => {
288                // Other expression types don't contain subqueries
289            }
290        }
291    }
292
293    /// Find column references to outer query (correlation)
294    fn find_outer_references(&self, subquery: &SelectStatement) -> Vec<String> {
295        let mut outer_refs = Vec::new();
296        let mut referenced_tables = HashSet::new();
297
298        // Collect all table references in the subquery
299        self.collect_table_references(subquery, &mut referenced_tables);
300
301        // Check if any referenced tables are from outer scopes
302        for table in &referenced_tables {
303            // Check all outer scopes (everything except current/innermost scope)
304            for scope in self.scope_stack.iter().rev().skip(1) {
305                if scope.contains(table) {
306                    outer_refs.push(table.clone());
307                    break;
308                }
309            }
310        }
311
312        outer_refs.sort();
313        outer_refs.dedup();
314        outer_refs
315    }
316
317    /// Collect all table/alias references in a statement
318    fn collect_table_references(&self, stmt: &SelectStatement, refs: &mut HashSet<String>) {
319        // Check WHERE clause for table-qualified columns
320        if let Some(ref where_clause) = stmt.where_clause {
321            self.collect_references_from_where(where_clause, refs);
322        }
323
324        // Check SELECT list
325        for item in &stmt.select_items {
326            if let crate::sql::parser::ast::SelectItem::Expression { expr, .. } = item {
327                self.collect_references_from_expr(expr, refs);
328            }
329        }
330
331        // Check HAVING
332        if let Some(ref having) = stmt.having {
333            self.collect_references_from_expr(having, refs);
334        }
335    }
336
337    /// Collect table references from WHERE clause
338    fn collect_references_from_where(
339        &self,
340        where_clause: &WhereClause,
341        refs: &mut HashSet<String>,
342    ) {
343        for condition in &where_clause.conditions {
344            self.collect_references_from_expr(&condition.expr, refs);
345        }
346    }
347
348    /// Collect table references from expression
349    fn collect_references_from_expr(&self, expr: &SqlExpression, refs: &mut HashSet<String>) {
350        match expr {
351            SqlExpression::Column(col_ref) => {
352                if let Some(ref table) = col_ref.table_prefix {
353                    refs.insert(table.clone());
354                }
355            }
356
357            SqlExpression::BinaryOp { left, right, .. } => {
358                self.collect_references_from_expr(left, refs);
359                self.collect_references_from_expr(right, refs);
360            }
361
362            SqlExpression::Not { expr } => {
363                self.collect_references_from_expr(expr, refs);
364            }
365
366            SqlExpression::FunctionCall { args, .. } => {
367                for arg in args {
368                    self.collect_references_from_expr(arg, refs);
369                }
370            }
371
372            SqlExpression::InList { expr, values } => {
373                self.collect_references_from_expr(expr, refs);
374                for val in values {
375                    self.collect_references_from_expr(val, refs);
376                }
377            }
378
379            SqlExpression::NotInList { expr, values } => {
380                self.collect_references_from_expr(expr, refs);
381                for val in values {
382                    self.collect_references_from_expr(val, refs);
383                }
384            }
385
386            _ => {
387                // Other expression types
388            }
389        }
390    }
391}
392
393impl Default for CorrelatedSubqueryAnalyzer {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain single-table query contains no subqueries at all.
    #[test]
    fn test_query_without_subqueries() {
        let mut analyzer = CorrelatedSubqueryAnalyzer::new();

        let main_stmt = SelectStatement {
            from_table: Some("trades".to_string()),
            ..Default::default()
        };

        let analysis = analyzer.analyze(&main_stmt);
        assert!(analysis.subqueries.is_empty());
        assert_eq!(analysis.total_count, 0);
        assert_eq!(analysis.correlated_count, 0);
        assert_eq!(analysis.non_correlated_count, 0);
    }

    /// A derived table in FROM is reported once and is not correlated.
    #[test]
    fn test_from_clause_subquery() {
        let mut analyzer = CorrelatedSubqueryAnalyzer::new();

        let subquery = SelectStatement {
            from_table: Some("inner_table".to_string()),
            ..Default::default()
        };

        let main_stmt = SelectStatement {
            from_subquery: Some(Box::new(subquery)),
            from_alias: Some("sub".to_string()),
            ..Default::default()
        };

        let analysis = analyzer.analyze(&main_stmt);
        assert_eq!(analysis.total_count, 1);
        assert_eq!(analysis.correlated_count, 0);
        assert_eq!(analysis.non_correlated_count, 1);
        assert_eq!(
            analysis.subqueries[0].location,
            SubqueryLocation::FromClause
        );
        assert_eq!(
            analysis.subqueries[0].subquery_type,
            SubqueryType::DerivedTable
        );
        assert!(!analysis.subqueries[0].is_correlated);
        assert!(analysis.subqueries[0].outer_references.is_empty());
    }

    /// An empty analysis prints the header, the counters and the placeholder.
    #[test]
    fn test_analysis_report_format() {
        let analysis = CorrelationAnalysis::default();

        let report = analysis.report();
        assert!(report.contains("Subquery Analysis"));
        assert!(report.contains("Total subqueries: 0"));
        assert!(report.contains("No subqueries detected"));
    }

    /// A populated analysis prints one line per subquery plus its references.
    #[test]
    fn test_analysis_report_lists_subqueries() {
        let analysis = CorrelationAnalysis {
            subqueries: vec![SubqueryInfo {
                location: SubqueryLocation::WhereClause,
                subquery_type: SubqueryType::Scalar,
                is_correlated: true,
                outer_references: vec!["t".to_string()],
                statement: SelectStatement::default(),
            }],
            total_count: 1,
            correlated_count: 1,
            non_correlated_count: 0,
        };

        let report = analysis.report();
        assert!(report.contains("Subquery #1: WhereClause - Scalar [CORRELATED]\n"));
        assert!(report.contains("  Outer references: [\"t\"]\n"));
        assert!(!report.contains("No subqueries detected"));
    }
}
459}