1use crate::sql::parser::ast::{SelectStatement, SqlExpression, WhereClause};
18use std::collections::HashSet;
19
/// Clause of the enclosing `SELECT` statement in which a subquery was found.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`; this lets it be
/// used as a key when grouping analysis results by location.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubqueryLocation {
    /// The subquery is the statement's FROM source.
    FromClause,
    /// The subquery appears inside a WHERE condition.
    WhereClause,
    /// The subquery appears inside a SELECT-list expression.
    SelectList,
    /// The subquery appears inside the HAVING expression.
    HavingClause,
    /// The subquery appears inside a JOIN condition.
    JoinCondition,
}
34
/// Syntactic form of a detected subquery.
///
/// Every payload is a plain `bool`, so the enum also derives `Eq` and `Hash`
/// and can be used as a set or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubqueryType {
    /// A subquery used as a single value inside an expression.
    Scalar,
    /// A subquery used as the right-hand side of `IN` (`negated: false`) or
    /// `NOT IN` (`negated: true`).
    InList { negated: bool },
    /// A subquery used with `EXISTS` (`negated: false`) or `NOT EXISTS`
    /// (`negated: true`).
    Exists { negated: bool },
    /// A subquery used as the FROM source of a statement.
    DerivedTable,
}
47
48#[derive(Debug, Clone)]
50pub struct SubqueryInfo {
51 pub location: SubqueryLocation,
53 pub subquery_type: SubqueryType,
55 pub is_correlated: bool,
57 pub outer_references: Vec<String>,
59 pub statement: SelectStatement,
61}
62
63#[derive(Debug, Default)]
65pub struct CorrelationAnalysis {
66 pub subqueries: Vec<SubqueryInfo>,
68 pub total_count: usize,
70 pub correlated_count: usize,
72 pub non_correlated_count: usize,
74}
75
76impl CorrelationAnalysis {
77 pub fn report(&self) -> String {
79 let mut report = String::new();
80
81 report.push_str(&format!("=== Subquery Analysis ===\n"));
82 report.push_str(&format!("Total subqueries: {}\n", self.total_count));
83 report.push_str(&format!(" Correlated: {}\n", self.correlated_count));
84 report.push_str(&format!(
85 " Non-correlated: {}\n",
86 self.non_correlated_count
87 ));
88 report.push_str("\n");
89
90 if self.subqueries.is_empty() {
91 report.push_str("No subqueries detected.\n");
92 return report;
93 }
94
95 for (idx, info) in self.subqueries.iter().enumerate() {
96 report.push_str(&format!("Subquery #{}: ", idx + 1));
97
98 report.push_str(&format!("{:?} - ", info.location));
100
101 report.push_str(&format!("{:?}", info.subquery_type));
103
104 if info.is_correlated {
106 report.push_str(" [CORRELATED]\n");
107 report.push_str(&format!(
108 " Outer references: {:?}\n",
109 info.outer_references
110 ));
111 } else {
112 report.push_str(" [NON-CORRELATED]\n");
113 }
114 }
115
116 report
117 }
118}
119
/// Walks a `SELECT` statement, records every subquery it finds and decides
/// whether each one references tables of an enclosing query.
///
/// Derives `Debug` and `Clone` so the public type can be logged and copied.
#[derive(Debug, Clone)]
pub struct CorrelatedSubqueryAnalyzer {
    /// Stack of name scopes, outermost first. Each scope holds the table
    /// names and aliases introduced by one statement's FROM clause.
    scope_stack: Vec<HashSet<String>>,
}
126
127impl CorrelatedSubqueryAnalyzer {
128 pub fn new() -> Self {
129 Self {
130 scope_stack: vec![HashSet::new()],
131 }
132 }
133
134 pub fn analyze(&mut self, stmt: &SelectStatement) -> CorrelationAnalysis {
136 let mut analysis = CorrelationAnalysis::default();
137
138 let mut current_scope = HashSet::new();
140 if let Some(ref table) = stmt.from_table {
141 current_scope.insert(table.clone());
142 }
143 if let Some(ref alias) = stmt.from_alias {
144 current_scope.insert(alias.clone());
145 }
146
147 self.scope_stack.push(current_scope);
149
150 self.analyze_from_clause(stmt, &mut analysis);
152 self.analyze_select_list(stmt, &mut analysis);
153 self.analyze_where_clause(stmt, &mut analysis);
154 self.analyze_having_clause(stmt, &mut analysis);
155
156 self.scope_stack.pop();
158
159 analysis.total_count = analysis.subqueries.len();
161 analysis.correlated_count = analysis
162 .subqueries
163 .iter()
164 .filter(|s| s.is_correlated)
165 .count();
166 analysis.non_correlated_count = analysis.total_count - analysis.correlated_count;
167
168 analysis
169 }
170
171 fn analyze_from_clause(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
173 if let Some(ref subquery) = stmt.from_subquery {
174 let outer_refs = self.find_outer_references(subquery);
175
176 analysis.subqueries.push(SubqueryInfo {
177 location: SubqueryLocation::FromClause,
178 subquery_type: SubqueryType::DerivedTable,
179 is_correlated: !outer_refs.is_empty(),
180 outer_references: outer_refs,
181 statement: (**subquery).clone(),
182 });
183 }
184 }
185
186 fn analyze_select_list(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
188 for item in &stmt.select_items {
189 if let crate::sql::parser::ast::SelectItem::Expression { expr, .. } = item {
190 self.analyze_expression_for_subqueries(
191 expr,
192 SubqueryLocation::SelectList,
193 analysis,
194 );
195 }
196 }
197 }
198
199 fn analyze_where_clause(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
201 if let Some(ref where_clause) = stmt.where_clause {
202 for condition in &where_clause.conditions {
203 self.analyze_expression_for_subqueries(
204 &condition.expr,
205 SubqueryLocation::WhereClause,
206 analysis,
207 );
208 }
209 }
210 }
211
212 fn analyze_having_clause(
214 &mut self,
215 stmt: &SelectStatement,
216 analysis: &mut CorrelationAnalysis,
217 ) {
218 if let Some(ref having_expr) = stmt.having {
219 self.analyze_expression_for_subqueries(
220 having_expr,
221 SubqueryLocation::HavingClause,
222 analysis,
223 );
224 }
225 }
226
227 fn analyze_expression_for_subqueries(
229 &mut self,
230 expr: &SqlExpression,
231 location: SubqueryLocation,
232 analysis: &mut CorrelationAnalysis,
233 ) {
234 match expr {
235 SqlExpression::ScalarSubquery { query } => {
236 let outer_refs = self.find_outer_references(query);
237
238 analysis.subqueries.push(SubqueryInfo {
239 location: location.clone(),
240 subquery_type: SubqueryType::Scalar,
241 is_correlated: !outer_refs.is_empty(),
242 outer_references: outer_refs,
243 statement: (**query).clone(),
244 });
245 }
246
247 SqlExpression::InSubquery { expr: _, subquery } => {
248 let outer_refs = self.find_outer_references(subquery);
249
250 analysis.subqueries.push(SubqueryInfo {
251 location: location.clone(),
252 subquery_type: SubqueryType::InList { negated: false },
253 is_correlated: !outer_refs.is_empty(),
254 outer_references: outer_refs,
255 statement: (**subquery).clone(),
256 });
257 }
258
259 SqlExpression::NotInSubquery { expr: _, subquery } => {
260 let outer_refs = self.find_outer_references(subquery);
261
262 analysis.subqueries.push(SubqueryInfo {
263 location: location.clone(),
264 subquery_type: SubqueryType::InList { negated: true },
265 is_correlated: !outer_refs.is_empty(),
266 outer_references: outer_refs,
267 statement: (**subquery).clone(),
268 });
269 }
270
271 SqlExpression::BinaryOp { left, right, .. } => {
273 self.analyze_expression_for_subqueries(left, location.clone(), analysis);
274 self.analyze_expression_for_subqueries(right, location, analysis);
275 }
276
277 SqlExpression::Not { expr } => {
278 self.analyze_expression_for_subqueries(expr, location, analysis);
279 }
280
281 SqlExpression::FunctionCall { args, .. } => {
282 for arg in args {
283 self.analyze_expression_for_subqueries(arg, location.clone(), analysis);
284 }
285 }
286
287 _ => {
288 }
290 }
291 }
292
293 fn find_outer_references(&self, subquery: &SelectStatement) -> Vec<String> {
295 let mut outer_refs = Vec::new();
296 let mut referenced_tables = HashSet::new();
297
298 self.collect_table_references(subquery, &mut referenced_tables);
300
301 for table in &referenced_tables {
303 for scope in self.scope_stack.iter().rev().skip(1) {
305 if scope.contains(table) {
306 outer_refs.push(table.clone());
307 break;
308 }
309 }
310 }
311
312 outer_refs.sort();
313 outer_refs.dedup();
314 outer_refs
315 }
316
317 fn collect_table_references(&self, stmt: &SelectStatement, refs: &mut HashSet<String>) {
319 if let Some(ref where_clause) = stmt.where_clause {
321 self.collect_references_from_where(where_clause, refs);
322 }
323
324 for item in &stmt.select_items {
326 if let crate::sql::parser::ast::SelectItem::Expression { expr, .. } = item {
327 self.collect_references_from_expr(expr, refs);
328 }
329 }
330
331 if let Some(ref having) = stmt.having {
333 self.collect_references_from_expr(having, refs);
334 }
335 }
336
337 fn collect_references_from_where(
339 &self,
340 where_clause: &WhereClause,
341 refs: &mut HashSet<String>,
342 ) {
343 for condition in &where_clause.conditions {
344 self.collect_references_from_expr(&condition.expr, refs);
345 }
346 }
347
348 fn collect_references_from_expr(&self, expr: &SqlExpression, refs: &mut HashSet<String>) {
350 match expr {
351 SqlExpression::Column(col_ref) => {
352 if let Some(ref table) = col_ref.table_prefix {
353 refs.insert(table.clone());
354 }
355 }
356
357 SqlExpression::BinaryOp { left, right, .. } => {
358 self.collect_references_from_expr(left, refs);
359 self.collect_references_from_expr(right, refs);
360 }
361
362 SqlExpression::Not { expr } => {
363 self.collect_references_from_expr(expr, refs);
364 }
365
366 SqlExpression::FunctionCall { args, .. } => {
367 for arg in args {
368 self.collect_references_from_expr(arg, refs);
369 }
370 }
371
372 SqlExpression::InList { expr, values } => {
373 self.collect_references_from_expr(expr, refs);
374 for val in values {
375 self.collect_references_from_expr(val, refs);
376 }
377 }
378
379 SqlExpression::NotInList { expr, values } => {
380 self.collect_references_from_expr(expr, refs);
381 for val in values {
382 self.collect_references_from_expr(val, refs);
383 }
384 }
385
386 _ => {
387 }
389 }
390 }
391}
392
393impl Default for CorrelatedSubqueryAnalyzer {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// A statement that contains no subquery at all produces an empty
    /// analysis with every counter at zero.
    #[test]
    fn test_non_correlated_scalar_subquery() {
        let mut analyzer = CorrelatedSubqueryAnalyzer::new();

        let main_stmt = SelectStatement {
            from_table: Some("trades".to_string()),
            ..Default::default()
        };

        let analysis = analyzer.analyze(&main_stmt);
        assert_eq!(analysis.total_count, 0);
        assert_eq!(analysis.correlated_count, 0);
        assert_eq!(analysis.non_correlated_count, 0);
        assert!(analysis.subqueries.is_empty());
    }

    /// A subquery used as the FROM source is reported as a non-correlated
    /// derived table.
    #[test]
    fn test_from_clause_subquery() {
        let mut analyzer = CorrelatedSubqueryAnalyzer::new();

        let subquery = SelectStatement {
            from_table: Some("inner_table".to_string()),
            ..Default::default()
        };

        let main_stmt = SelectStatement {
            from_subquery: Some(Box::new(subquery)),
            from_alias: Some("sub".to_string()),
            ..Default::default()
        };

        let analysis = analyzer.analyze(&main_stmt);
        assert_eq!(analysis.total_count, 1);
        assert_eq!(analysis.correlated_count, 0);
        assert_eq!(analysis.non_correlated_count, 1);

        let info = &analysis.subqueries[0];
        assert_eq!(info.location, SubqueryLocation::FromClause);
        assert_eq!(info.subquery_type, SubqueryType::DerivedTable);
        assert!(!info.is_correlated);
        assert!(info.outer_references.is_empty());
    }

    /// The report of an empty analysis contains the header, the zeroed total
    /// and the "nothing found" line.
    #[test]
    fn test_analysis_report_format() {
        let analysis = CorrelationAnalysis::default();

        let report = analysis.report();
        assert!(report.contains("Subquery Analysis"));
        assert!(report.contains("Total subqueries: 0"));
        assert!(report.contains("No subqueries detected"));
    }
}