1use std::collections::HashMap;
24
25use crate::ast::*;
26
/// The syntactic role a [`Scope`] plays inside the statement it was built from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScopeType {
    /// The outermost scope, created by `build_scope` for the statement itself.
    Root,
    /// A subquery appearing inside an expression (scalar subquery, `EXISTS`,
    /// `IN (subquery)`), or the query feeding an `INSERT`.
    Subquery,
    /// A subquery used as a table source in `FROM` / `JOIN`.
    DerivedTable,
    /// The body of a common table expression (`WITH name AS (...)`).
    Cte,
    /// One side (left or right) of a set operation.
    Union,
    /// NOTE(review): presumably a user-defined table function scope; no code
    /// visible in this file constructs this variant — confirm intended use.
    Udtf,
}
47
/// Something a scope can read rows from, stored under its alias or name in
/// [`Scope::sources`].
#[derive(Debug, Clone)]
pub enum Source {
    /// A table reference. Also used for alias-only placeholders synthesized
    /// for aliased table functions, `PIVOT`/`UNPIVOT` results and `UNNEST`.
    Table(TableRef),
    /// A copy of the scope built for a CTE or derived table.
    ///
    /// The copy is taken while the tree is being built, before correlation
    /// detection runs, so `is_correlated` / `external_columns` on this copy
    /// are not updated; consult the matching entry in `cte_scopes` or
    /// `derived_table_scopes` for those.
    Scope(Box<Scope>),
}
61
/// A column reference as it was written in an expression.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnRef {
    /// The table qualifier, if the column was written as `table.column`.
    pub table: Option<String>,
    /// The column name; `"*"` for a qualified wildcard such as `t.*`.
    pub name: String,
}
74
/// One level of name resolution in a SQL statement, together with the child
/// scopes nested inside it.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The role this scope plays in the enclosing statement.
    pub scope_type: ScopeType,

    /// Sources declared in this scope, keyed by alias when one is present and
    /// by table name otherwise. CTEs are registered under their CTE name.
    pub sources: HashMap<String, Source>,

    /// Column references collected from this scope's own expressions. The
    /// collector does not descend into subquery expressions; those columns
    /// belong to the corresponding child scope.
    pub columns: Vec<ColumnRef>,

    /// Qualified columns whose qualifier names a source visible in an
    /// enclosing scope but not declared in this one.
    pub external_columns: Vec<ColumnRef>,

    /// Scopes of subqueries used as table sources in this scope.
    pub derived_table_scopes: Vec<Scope>,

    /// Scopes of subqueries found inside this scope's expressions.
    pub subquery_scopes: Vec<Scope>,

    /// Scopes for the left and right side of a set operation, in that order.
    pub union_scopes: Vec<Scope>,

    /// Scopes of the CTEs declared by this scope's `WITH` clause, in
    /// declaration order.
    pub cte_scopes: Vec<Scope>,

    /// The subset of `sources` whose key appears as the table qualifier of at
    /// least one entry in `columns`.
    pub selected_sources: HashMap<String, Source>,

    /// `true` when at least one column was recorded in `external_columns`.
    pub is_correlated: bool,

    /// Owned copy of the statement this scope was built from, when recorded.
    expression: Option<ScopeExpression>,
}
124
/// The AST node a scope was built from.
///
/// The payload is stored but not read anywhere in the code visible here,
/// which is why `dead_code` is allowed on this type.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum ScopeExpression {
    /// A full statement (query) that introduced the scope.
    Statement(Statement),
}
131
132impl Scope {
133 fn new(scope_type: ScopeType) -> Self {
134 Self {
135 scope_type,
136 sources: HashMap::new(),
137 columns: Vec::new(),
138 external_columns: Vec::new(),
139 derived_table_scopes: Vec::new(),
140 subquery_scopes: Vec::new(),
141 union_scopes: Vec::new(),
142 cte_scopes: Vec::new(),
143 selected_sources: HashMap::new(),
144 is_correlated: false,
145 expression: None,
146 }
147 }
148
149 #[must_use]
151 pub fn source_names(&self) -> Vec<&str> {
152 self.sources.keys().map(String::as_str).collect()
153 }
154
155 #[must_use]
158 pub fn child_scopes(&self) -> Vec<&Scope> {
159 let mut children: Vec<&Scope> = Vec::new();
160 children.extend(self.derived_table_scopes.iter());
161 children.extend(self.subquery_scopes.iter());
162 children.extend(self.union_scopes.iter());
163 children.extend(self.cte_scopes.iter());
164 children
165 }
166
167 pub fn walk<F>(&self, visitor: &mut F)
169 where
170 F: FnMut(&Scope),
171 {
172 visitor(self);
173 for child in self.child_scopes() {
174 child.walk(visitor);
175 }
176 }
177}
178
179#[must_use]
188pub fn build_scope(statement: &Statement) -> Scope {
189 let mut scope = Scope::new(ScopeType::Root);
190 scope.expression = Some(ScopeExpression::Statement(statement.clone()));
191 build_scope_inner(statement, &mut scope, ScopeType::Root);
192 resolve_selected_sources(&mut scope);
193 detect_correlation(&mut scope, &[]);
194 scope
195}
196
197#[must_use]
202pub fn find_all_in_scope<'a, F>(scope: &'a Scope, predicate: &F) -> Vec<&'a ColumnRef>
203where
204 F: Fn(&ColumnRef) -> bool,
205{
206 scope.columns.iter().filter(|c| predicate(c)).collect()
207}
208
209fn build_scope_inner(statement: &Statement, scope: &mut Scope, _scope_type: ScopeType) {
214 match statement {
215 Statement::Select(sel) => build_select_scope(sel, scope),
216 Statement::SetOperation(set_op) => build_set_operation_scope(set_op, scope),
217 Statement::CreateView(cv) => {
218 build_scope_inner(&cv.query, scope, ScopeType::Root);
220 }
221 Statement::Insert(ins) => {
222 if let InsertSource::Query(q) = &ins.source {
223 let mut sub = Scope::new(ScopeType::Subquery);
224 build_scope_inner(q, &mut sub, ScopeType::Subquery);
225 resolve_selected_sources(&mut sub);
226 scope.subquery_scopes.push(sub);
227 }
228 }
229 Statement::Delete(del) => {
230 if let Some(wh) = &del.where_clause {
232 collect_columns_from_expr(wh, scope);
233 }
234 }
235 Statement::Update(upd) => {
236 for (_col, expr) in &upd.assignments {
238 collect_columns_from_expr(expr, scope);
239 }
240 if let Some(wh) = &upd.where_clause {
241 collect_columns_from_expr(wh, scope);
242 }
243 }
244 Statement::Explain(expl) => {
245 build_scope_inner(&expl.statement, scope, _scope_type);
246 }
247 _ => {}
249 }
250}
251
252fn build_select_scope(sel: &SelectStatement, scope: &mut Scope) {
254 for cte in &sel.ctes {
256 let mut cte_scope = Scope::new(ScopeType::Cte);
257 cte_scope.expression = Some(ScopeExpression::Statement(*cte.query.clone()));
258 build_scope_inner(&cte.query, &mut cte_scope, ScopeType::Cte);
259 resolve_selected_sources(&mut cte_scope);
260
261 scope
263 .sources
264 .insert(cte.name.clone(), Source::Scope(Box::new(cte_scope.clone())));
265 scope.cte_scopes.push(cte_scope);
266 }
267
268 if let Some(from) = &sel.from {
270 process_table_source(&from.source, scope);
271 }
272
273 for join in &sel.joins {
275 process_table_source(&join.table, scope);
276 if let Some(on) = &join.on {
277 collect_columns_from_expr(on, scope);
278 }
279 }
280
281 for item in &sel.columns {
283 match item {
284 SelectItem::Expr { expr, .. } => {
285 collect_columns_from_expr(expr, scope);
286 collect_subqueries_from_expr(expr, scope);
287 }
288 SelectItem::QualifiedWildcard { table } => {
289 scope.columns.push(ColumnRef {
292 table: Some(table.clone()),
293 name: "*".to_string(),
294 });
295 }
296 SelectItem::Wildcard => {}
297 }
298 }
299
300 if let Some(wh) = &sel.where_clause {
302 collect_columns_from_expr(wh, scope);
303 collect_subqueries_from_expr(wh, scope);
304 }
305
306 for expr in &sel.group_by {
308 collect_columns_from_expr(expr, scope);
309 }
310
311 if let Some(having) = &sel.having {
313 collect_columns_from_expr(having, scope);
314 collect_subqueries_from_expr(having, scope);
315 }
316
317 for item in &sel.order_by {
319 collect_columns_from_expr(&item.expr, scope);
320 }
321
322 if let Some(qualify) = &sel.qualify {
324 collect_columns_from_expr(qualify, scope);
325 collect_subqueries_from_expr(qualify, scope);
326 }
327}
328
329fn build_set_operation_scope(set_op: &SetOperationStatement, scope: &mut Scope) {
331 let mut left_scope = Scope::new(ScopeType::Union);
333 build_scope_inner(&set_op.left, &mut left_scope, ScopeType::Union);
334 resolve_selected_sources(&mut left_scope);
335 scope.union_scopes.push(left_scope);
336
337 let mut right_scope = Scope::new(ScopeType::Union);
338 build_scope_inner(&set_op.right, &mut right_scope, ScopeType::Union);
339 resolve_selected_sources(&mut right_scope);
340 scope.union_scopes.push(right_scope);
341
342 for item in &set_op.order_by {
344 collect_columns_from_expr(&item.expr, scope);
345 }
346}
347
348fn process_table_source(source: &TableSource, scope: &mut Scope) {
350 match source {
351 TableSource::Table(table_ref) => {
352 let key = table_ref
353 .alias
354 .as_deref()
355 .unwrap_or(&table_ref.name)
356 .to_string();
357 scope.sources.insert(key, Source::Table(table_ref.clone()));
358 }
359 TableSource::Subquery { query, alias, .. } => {
360 let mut dt_scope = Scope::new(ScopeType::DerivedTable);
361 dt_scope.expression = Some(ScopeExpression::Statement(*query.clone()));
362 build_scope_inner(query, &mut dt_scope, ScopeType::DerivedTable);
363 resolve_selected_sources(&mut dt_scope);
364
365 if let Some(alias) = alias {
366 scope
367 .sources
368 .insert(alias.clone(), Source::Scope(Box::new(dt_scope.clone())));
369 }
370 scope.derived_table_scopes.push(dt_scope);
371 }
372 TableSource::TableFunction { alias, .. } => {
373 if let Some(alias) = alias {
374 scope.sources.insert(
376 alias.clone(),
377 Source::Table(TableRef {
378 catalog: None,
379 schema: None,
380 name: alias.clone(),
381 alias: None,
382 name_quote_style: QuoteStyle::None,
383 alias_quote_style: QuoteStyle::None,
384 }),
385 );
386 }
387 }
388 TableSource::Lateral { source } => {
389 process_table_source(source, scope);
390 }
391 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
392 process_table_source(source, scope);
393 if let Some(alias) = alias {
394 scope.sources.insert(
395 alias.clone(),
396 Source::Table(TableRef {
397 catalog: None,
398 schema: None,
399 name: alias.clone(),
400 alias: None,
401 name_quote_style: QuoteStyle::None,
402 alias_quote_style: QuoteStyle::None,
403 }),
404 );
405 }
406 }
407 TableSource::Unnest { alias, .. } => {
408 if let Some(alias) = alias {
409 scope.sources.insert(
410 alias.clone(),
411 Source::Table(TableRef {
412 catalog: None,
413 schema: None,
414 name: alias.clone(),
415 alias: None,
416 name_quote_style: QuoteStyle::None,
417 alias_quote_style: QuoteStyle::None,
418 }),
419 );
420 }
421 }
422 }
423}
424
425fn collect_columns_from_expr(expr: &Expr, scope: &mut Scope) {
432 expr.walk(&mut |e| {
433 match e {
434 Expr::Column { table, name, .. } => {
435 scope.columns.push(ColumnRef {
436 table: table.clone(),
437 name: name.clone(),
438 });
439 true
440 }
441 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => false,
443 _ => true,
444 }
445 });
446}
447
448fn collect_subqueries_from_expr(expr: &Expr, scope: &mut Scope) {
450 match expr {
451 Expr::Subquery(stmt) => {
452 let mut sub = Scope::new(ScopeType::Subquery);
453 sub.expression = Some(ScopeExpression::Statement(*stmt.clone()));
454 build_scope_inner(stmt, &mut sub, ScopeType::Subquery);
455 resolve_selected_sources(&mut sub);
456 scope.subquery_scopes.push(sub);
457 }
458 Expr::Exists { subquery, .. } => {
459 let mut sub = Scope::new(ScopeType::Subquery);
460 sub.expression = Some(ScopeExpression::Statement(*subquery.clone()));
461 build_scope_inner(subquery, &mut sub, ScopeType::Subquery);
462 resolve_selected_sources(&mut sub);
463 scope.subquery_scopes.push(sub);
464 }
465 Expr::InSubquery {
466 expr: left,
467 subquery,
468 ..
469 } => {
470 collect_columns_from_expr(left, scope);
472
473 let mut sub = Scope::new(ScopeType::Subquery);
474 sub.expression = Some(ScopeExpression::Statement(*subquery.clone()));
475 build_scope_inner(subquery, &mut sub, ScopeType::Subquery);
476 resolve_selected_sources(&mut sub);
477 scope.subquery_scopes.push(sub);
478 }
479 _ => {
480 walk_expr_for_subqueries(expr, scope);
482 }
483 }
484}
485
486fn walk_expr_for_subqueries(expr: &Expr, scope: &mut Scope) {
489 match expr {
490 Expr::BinaryOp { left, right, .. } => {
491 collect_subqueries_from_expr(left, scope);
492 collect_subqueries_from_expr(right, scope);
493 }
494 Expr::UnaryOp { expr: inner, .. } => {
495 collect_subqueries_from_expr(inner, scope);
496 }
497 Expr::Function { args, filter, .. } => {
498 for arg in args {
499 collect_subqueries_from_expr(arg, scope);
500 }
501 if let Some(f) = filter {
502 collect_subqueries_from_expr(f, scope);
503 }
504 }
505 Expr::Nested(inner) => {
506 collect_subqueries_from_expr(inner, scope);
507 }
508 Expr::Case {
509 operand,
510 when_clauses,
511 else_clause,
512 } => {
513 if let Some(op) = operand {
514 collect_subqueries_from_expr(op, scope);
515 }
516 for (cond, result) in when_clauses {
517 collect_subqueries_from_expr(cond, scope);
518 collect_subqueries_from_expr(result, scope);
519 }
520 if let Some(el) = else_clause {
521 collect_subqueries_from_expr(el, scope);
522 }
523 }
524 Expr::Between {
525 expr: inner,
526 low,
527 high,
528 ..
529 } => {
530 collect_subqueries_from_expr(inner, scope);
531 collect_subqueries_from_expr(low, scope);
532 collect_subqueries_from_expr(high, scope);
533 }
534 Expr::InList {
535 expr: inner, list, ..
536 } => {
537 collect_subqueries_from_expr(inner, scope);
538 for item in list {
539 collect_subqueries_from_expr(item, scope);
540 }
541 }
542 Expr::Cast { expr: inner, .. } | Expr::TryCast { expr: inner, .. } => {
543 collect_subqueries_from_expr(inner, scope);
544 }
545 Expr::Coalesce(items) | Expr::ArrayLiteral(items) | Expr::Tuple(items) => {
546 for item in items {
547 collect_subqueries_from_expr(item, scope);
548 }
549 }
550 Expr::If {
551 condition,
552 true_val,
553 false_val,
554 } => {
555 collect_subqueries_from_expr(condition, scope);
556 collect_subqueries_from_expr(true_val, scope);
557 if let Some(fv) = false_val {
558 collect_subqueries_from_expr(fv, scope);
559 }
560 }
561 Expr::IsNull { expr: inner, .. } | Expr::IsBool { expr: inner, .. } => {
562 collect_subqueries_from_expr(inner, scope);
563 }
564 Expr::Like {
565 expr: inner,
566 pattern,
567 ..
568 }
569 | Expr::ILike {
570 expr: inner,
571 pattern,
572 ..
573 } => {
574 collect_subqueries_from_expr(inner, scope);
575 collect_subqueries_from_expr(pattern, scope);
576 }
577 Expr::Alias { expr: inner, .. } | Expr::Collate { expr: inner, .. } => {
578 collect_subqueries_from_expr(inner, scope);
579 }
580 Expr::NullIf {
581 expr: inner,
582 r#else,
583 } => {
584 collect_subqueries_from_expr(inner, scope);
585 collect_subqueries_from_expr(r#else, scope);
586 }
587 Expr::AnyOp {
588 expr: inner, right, ..
589 }
590 | Expr::AllOp {
591 expr: inner, right, ..
592 } => {
593 collect_subqueries_from_expr(inner, scope);
594 collect_subqueries_from_expr(right, scope);
595 }
596 Expr::ArrayIndex { expr: inner, index } => {
597 collect_subqueries_from_expr(inner, scope);
598 collect_subqueries_from_expr(index, scope);
599 }
600 Expr::JsonAccess {
601 expr: inner, path, ..
602 } => {
603 collect_subqueries_from_expr(inner, scope);
604 collect_subqueries_from_expr(path, scope);
605 }
606 Expr::Lambda { body, .. } => {
607 collect_subqueries_from_expr(body, scope);
608 }
609 Expr::Extract { expr: inner, .. } | Expr::Interval { value: inner, .. } => {
610 collect_subqueries_from_expr(inner, scope);
611 }
612 Expr::Cube { exprs } | Expr::Rollup { exprs } | Expr::GroupingSets { sets: exprs } => {
613 for item in exprs {
614 collect_subqueries_from_expr(item, scope);
615 }
616 }
617 _ => {}
619 }
620}
621
622#[allow(clippy::collapsible_if)]
629fn resolve_selected_sources(scope: &mut Scope) {
630 for col in &scope.columns {
631 if let Some(table) = &col.table {
632 if let Some(source) = scope.sources.get(table) {
633 scope
634 .selected_sources
635 .entry(table.clone())
636 .or_insert_with(|| source.clone());
637 }
638 }
639 }
640}
641
642fn detect_correlation(scope: &mut Scope, outer_source_names: &[String]) {
650 let mut visible: Vec<String> = outer_source_names.to_vec();
652 visible.extend(scope.sources.keys().cloned());
653
654 detect_correlation_in_children(&mut scope.subquery_scopes, &visible);
656 detect_correlation_in_children(&mut scope.derived_table_scopes, &visible);
657 detect_correlation_in_children(&mut scope.union_scopes, &visible);
658 detect_correlation_in_children(&mut scope.cte_scopes, &visible);
659}
660
661#[allow(clippy::collapsible_if)]
662fn detect_correlation_in_children(children: &mut [Scope], outer_names: &[String]) {
663 for child in children.iter_mut() {
664 for col in &child.columns {
667 if let Some(table) = &col.table {
668 if outer_names.contains(table) && !child.sources.contains_key(table) {
669 child.external_columns.push(col.clone());
670 child.is_correlated = true;
671 }
672 }
673 }
674
675 detect_correlation(child, outer_names);
677 }
678}
679
// Unit tests for scope construction. Each test parses a SQL string with the
// ANSI dialect and inspects the resulting scope tree.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Parses `sql` with the ANSI dialect and builds its root scope.
    /// Panics if the SQL fails to parse.
    fn scope_for(sql: &str) -> Scope {
        let ast = parse(sql, Dialect::Ansi).unwrap();
        build_scope(&ast)
    }

    // A flat SELECT yields a root scope with one source and no correlation.
    #[test]
    fn test_simple_select() {
        let scope = scope_for("SELECT a, b FROM t WHERE a > 1");
        assert_eq!(scope.scope_type, ScopeType::Root);
        assert!(scope.sources.contains_key("t"));
        // `a` and `b` from the select list, plus `a` again from WHERE.
        assert!(scope.columns.len() >= 2);
        assert!(scope.external_columns.is_empty());
        assert!(!scope.is_correlated);
    }

    // An aliased table is keyed by its alias, not its name.
    #[test]
    fn test_aliased_table() {
        let scope = scope_for("SELECT t1.x FROM my_table t1");
        assert!(scope.sources.contains_key("t1"));
        assert!(!scope.sources.contains_key("my_table"));
    }

    // Both join sides are registered and ON-clause columns are collected.
    #[test]
    fn test_join_sources() {
        let scope = scope_for("SELECT a.id, b.val FROM alpha a JOIN beta b ON a.id = b.id");
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
        let on_cols: Vec<_> = scope.columns.iter().filter(|c| c.name == "id").collect();
        assert!(on_cols.len() >= 2);
    }

    // A subquery in FROM becomes both a source and a derived-table scope.
    #[test]
    fn test_derived_table() {
        let scope = scope_for("SELECT sub.x FROM (SELECT a AS x FROM t) sub");
        assert!(scope.sources.contains_key("sub"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let dt = &scope.derived_table_scopes[0];
        assert_eq!(dt.scope_type, ScopeType::DerivedTable);
        assert!(dt.sources.contains_key("t"));
    }

    // A CTE is registered as a source and gets its own child scope.
    #[test]
    fn test_cte_scope() {
        let scope = scope_for("WITH cte AS (SELECT id FROM t) SELECT cte.id FROM cte");
        assert!(scope.sources.contains_key("cte"));
        assert_eq!(scope.cte_scopes.len(), 1);

        let cte = &scope.cte_scopes[0];
        assert_eq!(cte.scope_type, ScopeType::Cte);
        assert!(cte.sources.contains_key("t"));
    }

    // Every CTE in a WITH list produces its own scope and source entry.
    #[test]
    fn test_multiple_ctes() {
        let scope = scope_for(
            "WITH a AS (SELECT 1 AS x), b AS (SELECT 2 AS y) \
             SELECT a.x, b.y FROM a, b",
        );
        assert_eq!(scope.cte_scopes.len(), 2);
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
    }

    // A set operation produces one Union scope per side, left first.
    #[test]
    fn test_union_scopes() {
        let scope = scope_for("SELECT a FROM t1 UNION ALL SELECT b FROM t2");
        assert_eq!(scope.union_scopes.len(), 2);

        let left = &scope.union_scopes[0];
        assert_eq!(left.scope_type, ScopeType::Union);
        assert!(left.sources.contains_key("t1"));

        let right = &scope.union_scopes[1];
        assert!(right.sources.contains_key("t2"));
    }

    // A scalar subquery in the select list becomes a Subquery scope.
    #[test]
    fn test_scalar_subquery() {
        let scope = scope_for("SELECT (SELECT MAX(x) FROM t2) AS mx FROM t1");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert_eq!(sub.scope_type, ScopeType::Subquery);
        assert!(sub.sources.contains_key("t2"));
    }

    // EXISTS referencing the outer table is detected as correlated.
    #[test]
    fn test_exists_subquery() {
        let scope =
            scope_for("SELECT a FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.id = t1.id)");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert!(sub.sources.contains_key("t2"));
        assert!(sub.is_correlated);
        // `t1.id` is the only column whose qualifier is an outer source.
        assert!(!sub.external_columns.is_empty());
        let ext = &sub.external_columns[0];
        assert_eq!(ext.table.as_deref(), Some("t1"));
        assert_eq!(ext.name, "id");
    }

    // An IN-subquery without outer references is not correlated.
    #[test]
    fn test_in_subquery() {
        let scope = scope_for("SELECT a FROM t1 WHERE a IN (SELECT b FROM t2)");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert!(sub.sources.contains_key("t2"));
        assert!(!sub.is_correlated);
    }

    // A scalar subquery referencing `t1` from the outer query is correlated.
    #[test]
    fn test_correlated_subquery() {
        let scope =
            scope_for("SELECT a FROM t1 WHERE a = (SELECT MAX(b) FROM t2 WHERE t2.fk = t1.id)");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert!(sub.is_correlated);
        assert!(
            sub.external_columns
                .iter()
                .any(|c| c.table.as_deref() == Some("t1"))
        );
    }

    // Subqueries nest: the inner one hangs off the IN-subquery, not the root.
    #[test]
    fn test_nested_subqueries() {
        let scope = scope_for(
            "SELECT a FROM t1 WHERE a IN (SELECT b FROM t2 WHERE b > (SELECT MIN(c) FROM t3))",
        );
        assert_eq!(scope.subquery_scopes.len(), 1);

        let in_sub = &scope.subquery_scopes[0];
        assert!(in_sub.sources.contains_key("t2"));
        assert_eq!(in_sub.subquery_scopes.len(), 1);
        let inner = &in_sub.subquery_scopes[0];
        assert!(inner.sources.contains_key("t3"));
    }

    // A source referenced through a qualified column is a selected source.
    #[test]
    fn test_selected_sources() {
        let scope = scope_for("SELECT a.x FROM alpha a JOIN beta b ON a.id = b.id");
        assert!(scope.selected_sources.contains_key("a"));
    }

    // `find_all_in_scope` filters the scope's columns with a predicate.
    #[test]
    fn test_find_all_in_scope() {
        let scope = scope_for("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
        let t_cols = find_all_in_scope(&scope, &|c| c.table.as_deref() == Some("t"));
        // `t.a`, `t.b` from the select list and `t.id` from the ON clause.
        assert!(t_cols.len() >= 3);
    }

    // `walk` visits the root plus every descendant (here: CTE and EXISTS).
    #[test]
    fn test_scope_walk() {
        let scope = scope_for(
            "WITH cte AS (SELECT 1 AS a) \
             SELECT * FROM cte WHERE EXISTS (SELECT 1 FROM t)",
        );
        let mut count = 0;
        scope.walk(&mut |_| count += 1);
        assert!(count >= 3);
    }

    // `t.*` is recorded as a `*` column and marks `t` as selected.
    #[test]
    fn test_qualified_wildcard() {
        let scope = scope_for("SELECT t.* FROM t");
        assert!(
            scope
                .columns
                .iter()
                .any(|c| c.table.as_deref() == Some("t") && c.name == "*")
        );
        assert!(scope.selected_sources.contains_key("t"));
    }
}