1use std::collections::HashMap;
24
25use crate::ast::*;
26
/// The syntactic role a [`Scope`] plays inside a statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScopeType {
    /// The outermost scope, as returned by [`build_scope`].
    Root,
    /// A scalar, `EXISTS` or `IN` subquery, or the query feeding an `INSERT ... SELECT`.
    Subquery,
    /// A subquery used as a FROM/JOIN item, e.g. `FROM (SELECT ...) alias`.
    DerivedTable,
    /// The body of one `WITH` (common table expression) entry.
    Cte,
    /// One operand of a set operation such as `UNION`.
    Union,
    /// Not produced by the builder in this module; presumably reserved for
    /// user-defined table functions (TODO: confirm).
    Udtf,
}
47
48#[derive(Debug, Clone)]
55pub enum Source {
56 Table(TableRef),
58 Scope(Box<Scope>),
60}
61
/// A column reference collected from a scope's expressions, e.g. `t.col` or a bare `col`.
///
/// Derives `Eq` and `Hash` (both fields are plain strings) so column references
/// can be deduplicated in a `HashSet` or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// The qualifier as written in the SQL (a table name or alias), or `None`
    /// when the column is unqualified.
    pub table: Option<String>,
    /// The column name; `"*"` for a qualified wildcard such as `t.*`.
    pub name: String,
}
74
75#[derive(Debug, Clone)]
85pub struct Scope {
86 pub scope_type: ScopeType,
88
89 pub sources: HashMap<String, Source>,
93
94 pub columns: Vec<ColumnRef>,
97
98 pub external_columns: Vec<ColumnRef>,
100
101 pub derived_table_scopes: Vec<Scope>,
103
104 pub subquery_scopes: Vec<Scope>,
106
107 pub union_scopes: Vec<Scope>,
109
110 pub cte_scopes: Vec<Scope>,
112
113 pub selected_sources: HashMap<String, Source>,
116
117 pub is_correlated: bool,
119
120 expression: Option<ScopeExpression>,
123}
124
125#[derive(Debug, Clone)]
127#[allow(dead_code)]
128enum ScopeExpression {
129 Statement(Statement),
130}
131
132impl Scope {
133 fn new(scope_type: ScopeType) -> Self {
134 Self {
135 scope_type,
136 sources: HashMap::new(),
137 columns: Vec::new(),
138 external_columns: Vec::new(),
139 derived_table_scopes: Vec::new(),
140 subquery_scopes: Vec::new(),
141 union_scopes: Vec::new(),
142 cte_scopes: Vec::new(),
143 selected_sources: HashMap::new(),
144 is_correlated: false,
145 expression: None,
146 }
147 }
148
149 #[must_use]
151 pub fn source_names(&self) -> Vec<&str> {
152 self.sources.keys().map(String::as_str).collect()
153 }
154
155 #[must_use]
158 pub fn child_scopes(&self) -> Vec<&Scope> {
159 let mut children: Vec<&Scope> = Vec::new();
160 children.extend(self.derived_table_scopes.iter());
161 children.extend(self.subquery_scopes.iter());
162 children.extend(self.union_scopes.iter());
163 children.extend(self.cte_scopes.iter());
164 children
165 }
166
167 pub fn walk<F>(&self, visitor: &mut F)
169 where
170 F: FnMut(&Scope),
171 {
172 visitor(self);
173 for child in self.child_scopes() {
174 child.walk(visitor);
175 }
176 }
177}
178
179#[must_use]
188pub fn build_scope(statement: &Statement) -> Scope {
189 let mut scope = Scope::new(ScopeType::Root);
190 scope.expression = Some(ScopeExpression::Statement(statement.clone()));
191 build_scope_inner(statement, &mut scope, ScopeType::Root);
192 resolve_selected_sources(&mut scope);
193 detect_correlation(&mut scope, &[]);
194 scope
195}
196
197#[must_use]
202pub fn find_all_in_scope<'a, F>(scope: &'a Scope, predicate: &F) -> Vec<&'a ColumnRef>
203where
204 F: Fn(&ColumnRef) -> bool,
205{
206 scope.columns.iter().filter(|c| predicate(c)).collect()
207}
208
209fn build_scope_inner(statement: &Statement, scope: &mut Scope, _scope_type: ScopeType) {
214 match statement {
215 Statement::Select(sel) => build_select_scope(sel, scope),
216 Statement::SetOperation(set_op) => build_set_operation_scope(set_op, scope),
217 Statement::CreateView(cv) => {
218 build_scope_inner(&cv.query, scope, ScopeType::Root);
220 }
221 Statement::Insert(ins) => {
222 if let InsertSource::Query(q) = &ins.source {
223 let mut sub = Scope::new(ScopeType::Subquery);
224 build_scope_inner(q, &mut sub, ScopeType::Subquery);
225 resolve_selected_sources(&mut sub);
226 scope.subquery_scopes.push(sub);
227 }
228 }
229 Statement::Delete(del) => {
230 if let Some(wh) = &del.where_clause {
232 collect_columns_from_expr(wh, scope);
233 }
234 }
235 Statement::Update(upd) => {
236 for (_col, expr) in &upd.assignments {
238 collect_columns_from_expr(expr, scope);
239 }
240 if let Some(wh) = &upd.where_clause {
241 collect_columns_from_expr(wh, scope);
242 }
243 }
244 Statement::Explain(expl) => {
245 build_scope_inner(&expl.statement, scope, _scope_type);
246 }
247 _ => {}
249 }
250}
251
252fn build_select_scope(sel: &SelectStatement, scope: &mut Scope) {
254 for cte in &sel.ctes {
256 let mut cte_scope = Scope::new(ScopeType::Cte);
257 cte_scope.expression = Some(ScopeExpression::Statement(*cte.query.clone()));
258 build_scope_inner(&cte.query, &mut cte_scope, ScopeType::Cte);
259 resolve_selected_sources(&mut cte_scope);
260
261 scope
263 .sources
264 .insert(cte.name.clone(), Source::Scope(Box::new(cte_scope.clone())));
265 scope.cte_scopes.push(cte_scope);
266 }
267
268 if let Some(from) = &sel.from {
270 process_table_source(&from.source, scope);
271 }
272
273 for join in &sel.joins {
275 process_table_source(&join.table, scope);
276 if let Some(on) = &join.on {
277 collect_columns_from_expr(on, scope);
278 }
279 }
280
281 for item in &sel.columns {
283 match item {
284 SelectItem::Expr { expr, .. } => {
285 collect_columns_from_expr(expr, scope);
286 collect_subqueries_from_expr(expr, scope);
287 }
288 SelectItem::QualifiedWildcard { table } => {
289 scope.columns.push(ColumnRef {
292 table: Some(table.clone()),
293 name: "*".to_string(),
294 });
295 }
296 SelectItem::Wildcard => {}
297 }
298 }
299
300 if let Some(wh) = &sel.where_clause {
302 collect_columns_from_expr(wh, scope);
303 collect_subqueries_from_expr(wh, scope);
304 }
305
306 for expr in &sel.group_by {
308 collect_columns_from_expr(expr, scope);
309 }
310
311 if let Some(having) = &sel.having {
313 collect_columns_from_expr(having, scope);
314 collect_subqueries_from_expr(having, scope);
315 }
316
317 for item in &sel.order_by {
319 collect_columns_from_expr(&item.expr, scope);
320 }
321
322 if let Some(qualify) = &sel.qualify {
324 collect_columns_from_expr(qualify, scope);
325 collect_subqueries_from_expr(qualify, scope);
326 }
327}
328
329fn build_set_operation_scope(set_op: &SetOperationStatement, scope: &mut Scope) {
331 let mut left_scope = Scope::new(ScopeType::Union);
333 build_scope_inner(&set_op.left, &mut left_scope, ScopeType::Union);
334 resolve_selected_sources(&mut left_scope);
335 scope.union_scopes.push(left_scope);
336
337 let mut right_scope = Scope::new(ScopeType::Union);
338 build_scope_inner(&set_op.right, &mut right_scope, ScopeType::Union);
339 resolve_selected_sources(&mut right_scope);
340 scope.union_scopes.push(right_scope);
341
342 for item in &set_op.order_by {
344 collect_columns_from_expr(&item.expr, scope);
345 }
346}
347
348fn process_table_source(source: &TableSource, scope: &mut Scope) {
350 match source {
351 TableSource::Table(table_ref) => {
352 let key = table_ref
353 .alias
354 .as_deref()
355 .unwrap_or(&table_ref.name)
356 .to_string();
357 scope.sources.insert(key, Source::Table(table_ref.clone()));
358 }
359 TableSource::Subquery { query, alias } => {
360 let mut dt_scope = Scope::new(ScopeType::DerivedTable);
361 dt_scope.expression = Some(ScopeExpression::Statement(*query.clone()));
362 build_scope_inner(query, &mut dt_scope, ScopeType::DerivedTable);
363 resolve_selected_sources(&mut dt_scope);
364
365 if let Some(alias) = alias {
366 scope
367 .sources
368 .insert(alias.clone(), Source::Scope(Box::new(dt_scope.clone())));
369 }
370 scope.derived_table_scopes.push(dt_scope);
371 }
372 TableSource::TableFunction { alias, .. } => {
373 if let Some(alias) = alias {
374 scope.sources.insert(
376 alias.clone(),
377 Source::Table(TableRef {
378 catalog: None,
379 schema: None,
380 name: alias.clone(),
381 alias: None,
382 name_quote_style: QuoteStyle::None,
383 }),
384 );
385 }
386 }
387 TableSource::Lateral { source } => {
388 process_table_source(source, scope);
389 }
390 TableSource::Unnest { alias, .. } => {
391 if let Some(alias) = alias {
392 scope.sources.insert(
393 alias.clone(),
394 Source::Table(TableRef {
395 catalog: None,
396 schema: None,
397 name: alias.clone(),
398 alias: None,
399 name_quote_style: QuoteStyle::None,
400 }),
401 );
402 }
403 }
404 }
405}
406
407fn collect_columns_from_expr(expr: &Expr, scope: &mut Scope) {
414 expr.walk(&mut |e| {
415 match e {
416 Expr::Column { table, name, .. } => {
417 scope.columns.push(ColumnRef {
418 table: table.clone(),
419 name: name.clone(),
420 });
421 true
422 }
423 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => false,
425 _ => true,
426 }
427 });
428}
429
430fn collect_subqueries_from_expr(expr: &Expr, scope: &mut Scope) {
432 match expr {
433 Expr::Subquery(stmt) => {
434 let mut sub = Scope::new(ScopeType::Subquery);
435 sub.expression = Some(ScopeExpression::Statement(*stmt.clone()));
436 build_scope_inner(stmt, &mut sub, ScopeType::Subquery);
437 resolve_selected_sources(&mut sub);
438 scope.subquery_scopes.push(sub);
439 }
440 Expr::Exists { subquery, .. } => {
441 let mut sub = Scope::new(ScopeType::Subquery);
442 sub.expression = Some(ScopeExpression::Statement(*subquery.clone()));
443 build_scope_inner(subquery, &mut sub, ScopeType::Subquery);
444 resolve_selected_sources(&mut sub);
445 scope.subquery_scopes.push(sub);
446 }
447 Expr::InSubquery {
448 expr: left,
449 subquery,
450 ..
451 } => {
452 collect_columns_from_expr(left, scope);
454
455 let mut sub = Scope::new(ScopeType::Subquery);
456 sub.expression = Some(ScopeExpression::Statement(*subquery.clone()));
457 build_scope_inner(subquery, &mut sub, ScopeType::Subquery);
458 resolve_selected_sources(&mut sub);
459 scope.subquery_scopes.push(sub);
460 }
461 _ => {
462 walk_expr_for_subqueries(expr, scope);
464 }
465 }
466}
467
468fn walk_expr_for_subqueries(expr: &Expr, scope: &mut Scope) {
471 match expr {
472 Expr::BinaryOp { left, right, .. } => {
473 collect_subqueries_from_expr(left, scope);
474 collect_subqueries_from_expr(right, scope);
475 }
476 Expr::UnaryOp { expr: inner, .. } => {
477 collect_subqueries_from_expr(inner, scope);
478 }
479 Expr::Function { args, filter, .. } => {
480 for arg in args {
481 collect_subqueries_from_expr(arg, scope);
482 }
483 if let Some(f) = filter {
484 collect_subqueries_from_expr(f, scope);
485 }
486 }
487 Expr::Nested(inner) => {
488 collect_subqueries_from_expr(inner, scope);
489 }
490 Expr::Case {
491 operand,
492 when_clauses,
493 else_clause,
494 } => {
495 if let Some(op) = operand {
496 collect_subqueries_from_expr(op, scope);
497 }
498 for (cond, result) in when_clauses {
499 collect_subqueries_from_expr(cond, scope);
500 collect_subqueries_from_expr(result, scope);
501 }
502 if let Some(el) = else_clause {
503 collect_subqueries_from_expr(el, scope);
504 }
505 }
506 Expr::Between {
507 expr: inner,
508 low,
509 high,
510 ..
511 } => {
512 collect_subqueries_from_expr(inner, scope);
513 collect_subqueries_from_expr(low, scope);
514 collect_subqueries_from_expr(high, scope);
515 }
516 Expr::InList {
517 expr: inner, list, ..
518 } => {
519 collect_subqueries_from_expr(inner, scope);
520 for item in list {
521 collect_subqueries_from_expr(item, scope);
522 }
523 }
524 Expr::Cast { expr: inner, .. } | Expr::TryCast { expr: inner, .. } => {
525 collect_subqueries_from_expr(inner, scope);
526 }
527 Expr::Coalesce(items) | Expr::ArrayLiteral(items) | Expr::Tuple(items) => {
528 for item in items {
529 collect_subqueries_from_expr(item, scope);
530 }
531 }
532 Expr::If {
533 condition,
534 true_val,
535 false_val,
536 } => {
537 collect_subqueries_from_expr(condition, scope);
538 collect_subqueries_from_expr(true_val, scope);
539 if let Some(fv) = false_val {
540 collect_subqueries_from_expr(fv, scope);
541 }
542 }
543 Expr::IsNull { expr: inner, .. } | Expr::IsBool { expr: inner, .. } => {
544 collect_subqueries_from_expr(inner, scope);
545 }
546 Expr::Like {
547 expr: inner,
548 pattern,
549 ..
550 }
551 | Expr::ILike {
552 expr: inner,
553 pattern,
554 ..
555 } => {
556 collect_subqueries_from_expr(inner, scope);
557 collect_subqueries_from_expr(pattern, scope);
558 }
559 Expr::Alias { expr: inner, .. } | Expr::Collate { expr: inner, .. } => {
560 collect_subqueries_from_expr(inner, scope);
561 }
562 Expr::NullIf {
563 expr: inner,
564 r#else,
565 } => {
566 collect_subqueries_from_expr(inner, scope);
567 collect_subqueries_from_expr(r#else, scope);
568 }
569 Expr::AnyOp {
570 expr: inner, right, ..
571 }
572 | Expr::AllOp {
573 expr: inner, right, ..
574 } => {
575 collect_subqueries_from_expr(inner, scope);
576 collect_subqueries_from_expr(right, scope);
577 }
578 Expr::ArrayIndex { expr: inner, index } => {
579 collect_subqueries_from_expr(inner, scope);
580 collect_subqueries_from_expr(index, scope);
581 }
582 Expr::JsonAccess {
583 expr: inner, path, ..
584 } => {
585 collect_subqueries_from_expr(inner, scope);
586 collect_subqueries_from_expr(path, scope);
587 }
588 Expr::Lambda { body, .. } => {
589 collect_subqueries_from_expr(body, scope);
590 }
591 Expr::Extract { expr: inner, .. } | Expr::Interval { value: inner, .. } => {
592 collect_subqueries_from_expr(inner, scope);
593 }
594 Expr::Cube { exprs } | Expr::Rollup { exprs } | Expr::GroupingSets { sets: exprs } => {
595 for item in exprs {
596 collect_subqueries_from_expr(item, scope);
597 }
598 }
599 _ => {}
601 }
602}
603
604#[allow(clippy::collapsible_if)]
611fn resolve_selected_sources(scope: &mut Scope) {
612 for col in &scope.columns {
613 if let Some(table) = &col.table {
614 if let Some(source) = scope.sources.get(table) {
615 scope
616 .selected_sources
617 .entry(table.clone())
618 .or_insert_with(|| source.clone());
619 }
620 }
621 }
622}
623
624fn detect_correlation(scope: &mut Scope, outer_source_names: &[String]) {
632 let mut visible: Vec<String> = outer_source_names.to_vec();
634 visible.extend(scope.sources.keys().cloned());
635
636 detect_correlation_in_children(&mut scope.subquery_scopes, &visible);
638 detect_correlation_in_children(&mut scope.derived_table_scopes, &visible);
639 detect_correlation_in_children(&mut scope.union_scopes, &visible);
640 detect_correlation_in_children(&mut scope.cte_scopes, &visible);
641}
642
643#[allow(clippy::collapsible_if)]
644fn detect_correlation_in_children(children: &mut [Scope], outer_names: &[String]) {
645 for child in children.iter_mut() {
646 for col in &child.columns {
649 if let Some(table) = &col.table {
650 if outer_names.contains(table) && !child.sources.contains_key(table) {
651 child.external_columns.push(col.clone());
652 child.is_correlated = true;
653 }
654 }
655 }
656
657 detect_correlation(child, outer_names);
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    /// Parses `sql` with the ANSI dialect and returns its root scope.
    fn scope_for(sql: &str) -> Scope {
        let statement = parse(sql, Dialect::Ansi).expect("test SQL should parse");
        build_scope(&statement)
    }

    /// Asserts that every name in `names` is registered as a source of `scope`.
    fn assert_has_sources(scope: &Scope, names: &[&str]) {
        for &name in names {
            assert!(scope.sources.contains_key(name), "missing source `{name}`");
        }
    }

    #[test]
    fn test_simple_select() {
        let scope = scope_for("SELECT a, b FROM t WHERE a > 1");
        assert_eq!(scope.scope_type, ScopeType::Root);
        assert_has_sources(&scope, &["t"]);
        assert!(scope.columns.len() >= 2, "expected the projected columns");
        assert!(scope.external_columns.is_empty());
        assert!(!scope.is_correlated);
    }

    #[test]
    fn test_aliased_table() {
        let scope = scope_for("SELECT t1.x FROM my_table t1");
        assert_has_sources(&scope, &["t1"]);
        assert!(
            !scope.sources.contains_key("my_table"),
            "an aliased table is keyed by its alias only"
        );
    }

    #[test]
    fn test_join_sources() {
        let scope = scope_for("SELECT a.id, b.val FROM alpha a JOIN beta b ON a.id = b.id");
        assert_has_sources(&scope, &["a", "b"]);
        let id_refs = scope.columns.iter().filter(|c| c.name == "id").count();
        assert!(id_refs >= 2, "ON-clause columns should be collected");
    }

    #[test]
    fn test_derived_table() {
        let scope = scope_for("SELECT sub.x FROM (SELECT a AS x FROM t) sub");
        assert_has_sources(&scope, &["sub"]);
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let derived = &scope.derived_table_scopes[0];
        assert_eq!(derived.scope_type, ScopeType::DerivedTable);
        assert_has_sources(derived, &["t"]);
    }

    #[test]
    fn test_cte_scope() {
        let scope = scope_for("WITH cte AS (SELECT id FROM t) SELECT cte.id FROM cte");
        assert_has_sources(&scope, &["cte"]);
        assert_eq!(scope.cte_scopes.len(), 1);

        let cte = &scope.cte_scopes[0];
        assert_eq!(cte.scope_type, ScopeType::Cte);
        assert_has_sources(cte, &["t"]);
    }

    #[test]
    fn test_multiple_ctes() {
        let scope = scope_for(
            "WITH a AS (SELECT 1 AS x), b AS (SELECT 2 AS y) \
             SELECT a.x, b.y FROM a, b",
        );
        assert_eq!(scope.cte_scopes.len(), 2);
        assert_has_sources(&scope, &["a", "b"]);
    }

    #[test]
    fn test_union_scopes() {
        let scope = scope_for("SELECT a FROM t1 UNION ALL SELECT b FROM t2");
        assert_eq!(scope.union_scopes.len(), 2);

        let (left, right) = (&scope.union_scopes[0], &scope.union_scopes[1]);
        assert_eq!(left.scope_type, ScopeType::Union);
        assert_has_sources(left, &["t1"]);
        assert_has_sources(right, &["t2"]);
    }

    #[test]
    fn test_scalar_subquery() {
        let scope = scope_for("SELECT (SELECT MAX(x) FROM t2) AS mx FROM t1");
        assert_eq!(scope.subquery_scopes.len(), 1);

        let sub = &scope.subquery_scopes[0];
        assert_eq!(sub.scope_type, ScopeType::Subquery);
        assert_has_sources(sub, &["t2"]);
    }

    #[test]
    fn test_exists_subquery() {
        let scope =
            scope_for("SELECT a FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.id = t1.id)");
        assert_eq!(scope.subquery_scopes.len(), 1);

        let sub = &scope.subquery_scopes[0];
        assert_has_sources(sub, &["t2"]);
        assert!(sub.is_correlated);

        let first_external = sub
            .external_columns
            .first()
            .expect("t1.id should be recorded as external");
        assert_eq!(first_external.table.as_deref(), Some("t1"));
        assert_eq!(first_external.name, "id");
    }

    #[test]
    fn test_in_subquery() {
        let scope = scope_for("SELECT a FROM t1 WHERE a IN (SELECT b FROM t2)");
        assert_eq!(scope.subquery_scopes.len(), 1);

        let sub = &scope.subquery_scopes[0];
        assert_has_sources(sub, &["t2"]);
        assert!(!sub.is_correlated);
    }

    #[test]
    fn test_correlated_subquery() {
        let scope =
            scope_for("SELECT a FROM t1 WHERE a = (SELECT MAX(b) FROM t2 WHERE t2.fk = t1.id)");
        assert_eq!(scope.subquery_scopes.len(), 1);

        let sub = &scope.subquery_scopes[0];
        assert!(sub.is_correlated);
        let refers_to_outer = sub
            .external_columns
            .iter()
            .any(|c| c.table.as_deref() == Some("t1"));
        assert!(refers_to_outer, "expected an external reference to t1");
    }

    #[test]
    fn test_nested_subqueries() {
        let scope = scope_for(
            "SELECT a FROM t1 WHERE a IN (SELECT b FROM t2 WHERE b > (SELECT MIN(c) FROM t3))",
        );
        assert_eq!(scope.subquery_scopes.len(), 1);

        let in_sub = &scope.subquery_scopes[0];
        assert_has_sources(in_sub, &["t2"]);
        assert_eq!(in_sub.subquery_scopes.len(), 1);

        let innermost = &in_sub.subquery_scopes[0];
        assert_has_sources(innermost, &["t3"]);
    }

    #[test]
    fn test_selected_sources() {
        let scope = scope_for("SELECT a.x FROM alpha a JOIN beta b ON a.id = b.id");
        assert!(scope.selected_sources.contains_key("a"));
    }

    #[test]
    fn test_find_all_in_scope() {
        let scope = scope_for("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
        let t_cols = find_all_in_scope(&scope, &|c| c.table.as_deref() == Some("t"));
        assert!(t_cols.len() >= 3, "expected t.a, t.b and t.id");
    }

    #[test]
    fn test_scope_walk() {
        let scope = scope_for(
            "WITH cte AS (SELECT 1 AS a) \
             SELECT * FROM cte WHERE EXISTS (SELECT 1 FROM t)",
        );
        let mut visited = 0;
        scope.walk(&mut |_| visited += 1);
        assert!(visited >= 3, "expected root, CTE and EXISTS scopes");
    }

    #[test]
    fn test_qualified_wildcard() {
        let scope = scope_for("SELECT t.* FROM t");
        let has_star = scope
            .columns
            .iter()
            .any(|c| c.table.as_deref() == Some("t") && c.name == "*");
        assert!(has_star, "`t.*` should be recorded as a `*` column of t");
        assert!(scope.selected_sources.contains_key("t"));
    }
}