1use std::collections::HashMap;
24
25use crate::ast::*;
26
/// The syntactic role a [`Scope`] plays in the statement it was built from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScopeType {
    /// The outermost scope, created for the statement passed to `build_scope`.
    Root,
    /// A query used as a subquery: a scalar subquery expression, `EXISTS (...)`,
    /// `IN (SELECT ...)`, or the query that feeds an `INSERT`.
    Subquery,
    /// A subquery that appears as a table source (a derived table).
    DerivedTable,
    /// The body of a common table expression.
    Cte,
    /// One operand (left or right) of a set operation.
    Union,
    /// NOTE(review): not constructed in the code visible here; presumably
    /// reserved for user-defined table functions — confirm before relying on it.
    Udtf,
}
47
/// Something a query can select from, as stored in a [`Scope`]'s source maps.
#[derive(Debug, Clone)]
pub enum Source {
    /// A table reference. Also used as an alias-only placeholder for table
    /// functions, `PIVOT` / `UNPIVOT` results and `UNNEST` sources.
    Table(TableRef),
    /// A nested scope: a CTE or an aliased derived table. The boxed scope is
    /// an independent clone of the child scope stored on the parent.
    Scope(Box<Scope>),
}
61
/// A reference to a column as written in the query.
#[derive(Debug, Clone, PartialEq)]
pub struct ColumnRef {
    /// The table qualifier (`t` in `t.x`), or `None` for an unqualified column.
    pub table: Option<String>,
    /// The column name. A qualified wildcard (`t.*`) is recorded with the
    /// name `"*"`.
    pub name: String,
}
74
/// One node of the scope tree produced by `build_scope`.
///
/// A scope records the sources and column references of a single query level
/// and owns the scopes of the queries nested inside it.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The role of this scope within the statement it was built from.
    pub scope_type: ScopeType,

    /// Sources registered in this scope, keyed by the name used to reference
    /// them: a table's alias if it has one (otherwise its name), a CTE's name,
    /// or the alias of a derived table, table function, `PIVOT` / `UNPIVOT`
    /// or `UNNEST` source.
    pub sources: HashMap<String, Source>,

    /// Column references collected from this scope's own expressions.
    /// References inside nested subquery expressions are not recorded here;
    /// they belong to the corresponding child scope.
    pub columns: Vec<ColumnRef>,

    /// Qualified columns from `columns` whose qualifier is not a key of
    /// `sources` but names a source visible from an enclosing scope.
    pub external_columns: Vec<ColumnRef>,

    /// Scopes of subqueries used as table sources in this scope.
    pub derived_table_scopes: Vec<Scope>,

    /// Scopes of subquery expressions (scalar, `EXISTS`, `IN`) found in this
    /// scope, and of the query feeding an `INSERT`.
    pub subquery_scopes: Vec<Scope>,

    /// Scopes of the left and right operands when this scope was built from a
    /// set operation (left first).
    pub union_scopes: Vec<Scope>,

    /// Scopes of the CTEs declared on this scope's `SELECT`, in declaration
    /// order.
    pub cte_scopes: Vec<Scope>,

    /// The entries of `sources` that are referenced by at least one
    /// table-qualified column in `columns`.
    pub selected_sources: HashMap<String, Source>,

    /// Set to `true` when at least one external column was found for this
    /// scope.
    pub is_correlated: bool,

    /// The statement this scope was built from, when recorded. It is stored
    /// for the root, CTE, derived-table and expression-subquery scopes and is
    /// left as `None` for set-operation operands and `INSERT` query scopes.
    expression: Option<ScopeExpression>,
}
124
/// The AST node a [`Scope`] was built from.
///
/// `dead_code` is allowed because the payload is stored on the scope but is
/// not read back anywhere in the code visible here.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum ScopeExpression {
    /// A clone of the statement the scope was built from.
    Statement(Statement),
}
131
132impl Scope {
133 fn new(scope_type: ScopeType) -> Self {
134 Self {
135 scope_type,
136 sources: HashMap::new(),
137 columns: Vec::new(),
138 external_columns: Vec::new(),
139 derived_table_scopes: Vec::new(),
140 subquery_scopes: Vec::new(),
141 union_scopes: Vec::new(),
142 cte_scopes: Vec::new(),
143 selected_sources: HashMap::new(),
144 is_correlated: false,
145 expression: None,
146 }
147 }
148
149 #[must_use]
151 pub fn source_names(&self) -> Vec<&str> {
152 self.sources.keys().map(String::as_str).collect()
153 }
154
155 #[must_use]
158 pub fn child_scopes(&self) -> Vec<&Scope> {
159 let mut children: Vec<&Scope> = Vec::new();
160 children.extend(self.derived_table_scopes.iter());
161 children.extend(self.subquery_scopes.iter());
162 children.extend(self.union_scopes.iter());
163 children.extend(self.cte_scopes.iter());
164 children
165 }
166
167 pub fn walk<F>(&self, visitor: &mut F)
169 where
170 F: FnMut(&Scope),
171 {
172 visitor(self);
173 for child in self.child_scopes() {
174 child.walk(visitor);
175 }
176 }
177}
178
179#[must_use]
188pub fn build_scope(statement: &Statement) -> Scope {
189 let mut scope = Scope::new(ScopeType::Root);
190 scope.expression = Some(ScopeExpression::Statement(statement.clone()));
191 build_scope_inner(statement, &mut scope, ScopeType::Root);
192 resolve_selected_sources(&mut scope);
193 detect_correlation(&mut scope, &[]);
194 scope
195}
196
197#[must_use]
202pub fn find_all_in_scope<'a, F>(scope: &'a Scope, predicate: &F) -> Vec<&'a ColumnRef>
203where
204 F: Fn(&ColumnRef) -> bool,
205{
206 scope.columns.iter().filter(|c| predicate(c)).collect()
207}
208
209fn build_scope_inner(statement: &Statement, scope: &mut Scope, _scope_type: ScopeType) {
214 match statement {
215 Statement::Select(sel) => build_select_scope(sel, scope),
216 Statement::SetOperation(set_op) => build_set_operation_scope(set_op, scope),
217 Statement::CreateView(cv) => {
218 build_scope_inner(&cv.query, scope, ScopeType::Root);
220 }
221 Statement::Insert(ins) => {
222 if let InsertSource::Query(q) = &ins.source {
223 let mut sub = Scope::new(ScopeType::Subquery);
224 build_scope_inner(q, &mut sub, ScopeType::Subquery);
225 resolve_selected_sources(&mut sub);
226 scope.subquery_scopes.push(sub);
227 }
228 }
229 Statement::Delete(del) => {
230 if let Some(wh) = &del.where_clause {
232 collect_columns_from_expr(wh, scope);
233 }
234 }
235 Statement::Update(upd) => {
236 for (_col, expr) in &upd.assignments {
238 collect_columns_from_expr(expr, scope);
239 }
240 if let Some(wh) = &upd.where_clause {
241 collect_columns_from_expr(wh, scope);
242 }
243 }
244 Statement::Explain(expl) => {
245 build_scope_inner(&expl.statement, scope, _scope_type);
246 }
247 _ => {}
249 }
250}
251
252fn build_select_scope(sel: &SelectStatement, scope: &mut Scope) {
254 for cte in &sel.ctes {
256 let mut cte_scope = Scope::new(ScopeType::Cte);
257 cte_scope.expression = Some(ScopeExpression::Statement(*cte.query.clone()));
258 build_scope_inner(&cte.query, &mut cte_scope, ScopeType::Cte);
259 resolve_selected_sources(&mut cte_scope);
260
261 scope
263 .sources
264 .insert(cte.name.clone(), Source::Scope(Box::new(cte_scope.clone())));
265 scope.cte_scopes.push(cte_scope);
266 }
267
268 if let Some(from) = &sel.from {
270 process_table_source(&from.source, scope);
271 }
272
273 for join in &sel.joins {
275 process_table_source(&join.table, scope);
276 if let Some(on) = &join.on {
277 collect_columns_from_expr(on, scope);
278 }
279 }
280
281 for item in &sel.columns {
283 match item {
284 SelectItem::Expr { expr, .. } => {
285 collect_columns_from_expr(expr, scope);
286 collect_subqueries_from_expr(expr, scope);
287 }
288 SelectItem::QualifiedWildcard { table } => {
289 scope.columns.push(ColumnRef {
292 table: Some(table.clone()),
293 name: "*".to_string(),
294 });
295 }
296 SelectItem::Wildcard => {}
297 }
298 }
299
300 if let Some(wh) = &sel.where_clause {
302 collect_columns_from_expr(wh, scope);
303 collect_subqueries_from_expr(wh, scope);
304 }
305
306 for expr in &sel.group_by {
308 collect_columns_from_expr(expr, scope);
309 }
310
311 if let Some(having) = &sel.having {
313 collect_columns_from_expr(having, scope);
314 collect_subqueries_from_expr(having, scope);
315 }
316
317 for item in &sel.order_by {
319 collect_columns_from_expr(&item.expr, scope);
320 }
321
322 if let Some(qualify) = &sel.qualify {
324 collect_columns_from_expr(qualify, scope);
325 collect_subqueries_from_expr(qualify, scope);
326 }
327}
328
329fn build_set_operation_scope(set_op: &SetOperationStatement, scope: &mut Scope) {
331 let mut left_scope = Scope::new(ScopeType::Union);
333 build_scope_inner(&set_op.left, &mut left_scope, ScopeType::Union);
334 resolve_selected_sources(&mut left_scope);
335 scope.union_scopes.push(left_scope);
336
337 let mut right_scope = Scope::new(ScopeType::Union);
338 build_scope_inner(&set_op.right, &mut right_scope, ScopeType::Union);
339 resolve_selected_sources(&mut right_scope);
340 scope.union_scopes.push(right_scope);
341
342 for item in &set_op.order_by {
344 collect_columns_from_expr(&item.expr, scope);
345 }
346}
347
348fn process_table_source(source: &TableSource, scope: &mut Scope) {
350 match source {
351 TableSource::Table(table_ref) => {
352 let key = table_ref
353 .alias
354 .as_deref()
355 .unwrap_or(&table_ref.name)
356 .to_string();
357 scope.sources.insert(key, Source::Table(table_ref.clone()));
358 }
359 TableSource::Subquery { query, alias } => {
360 let mut dt_scope = Scope::new(ScopeType::DerivedTable);
361 dt_scope.expression = Some(ScopeExpression::Statement(*query.clone()));
362 build_scope_inner(query, &mut dt_scope, ScopeType::DerivedTable);
363 resolve_selected_sources(&mut dt_scope);
364
365 if let Some(alias) = alias {
366 scope
367 .sources
368 .insert(alias.clone(), Source::Scope(Box::new(dt_scope.clone())));
369 }
370 scope.derived_table_scopes.push(dt_scope);
371 }
372 TableSource::TableFunction { alias, .. } => {
373 if let Some(alias) = alias {
374 scope.sources.insert(
376 alias.clone(),
377 Source::Table(TableRef {
378 catalog: None,
379 schema: None,
380 name: alias.clone(),
381 alias: None,
382 name_quote_style: QuoteStyle::None,
383 }),
384 );
385 }
386 }
387 TableSource::Lateral { source } => {
388 process_table_source(source, scope);
389 }
390 TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
391 process_table_source(source, scope);
392 if let Some(alias) = alias {
393 scope.sources.insert(
394 alias.clone(),
395 Source::Table(TableRef {
396 catalog: None,
397 schema: None,
398 name: alias.clone(),
399 alias: None,
400 name_quote_style: QuoteStyle::None,
401 }),
402 );
403 }
404 }
405 TableSource::Unnest { alias, .. } => {
406 if let Some(alias) = alias {
407 scope.sources.insert(
408 alias.clone(),
409 Source::Table(TableRef {
410 catalog: None,
411 schema: None,
412 name: alias.clone(),
413 alias: None,
414 name_quote_style: QuoteStyle::None,
415 }),
416 );
417 }
418 }
419 }
420}
421
422fn collect_columns_from_expr(expr: &Expr, scope: &mut Scope) {
429 expr.walk(&mut |e| {
430 match e {
431 Expr::Column { table, name, .. } => {
432 scope.columns.push(ColumnRef {
433 table: table.clone(),
434 name: name.clone(),
435 });
436 true
437 }
438 Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => false,
440 _ => true,
441 }
442 });
443}
444
445fn collect_subqueries_from_expr(expr: &Expr, scope: &mut Scope) {
447 match expr {
448 Expr::Subquery(stmt) => {
449 let mut sub = Scope::new(ScopeType::Subquery);
450 sub.expression = Some(ScopeExpression::Statement(*stmt.clone()));
451 build_scope_inner(stmt, &mut sub, ScopeType::Subquery);
452 resolve_selected_sources(&mut sub);
453 scope.subquery_scopes.push(sub);
454 }
455 Expr::Exists { subquery, .. } => {
456 let mut sub = Scope::new(ScopeType::Subquery);
457 sub.expression = Some(ScopeExpression::Statement(*subquery.clone()));
458 build_scope_inner(subquery, &mut sub, ScopeType::Subquery);
459 resolve_selected_sources(&mut sub);
460 scope.subquery_scopes.push(sub);
461 }
462 Expr::InSubquery {
463 expr: left,
464 subquery,
465 ..
466 } => {
467 collect_columns_from_expr(left, scope);
469
470 let mut sub = Scope::new(ScopeType::Subquery);
471 sub.expression = Some(ScopeExpression::Statement(*subquery.clone()));
472 build_scope_inner(subquery, &mut sub, ScopeType::Subquery);
473 resolve_selected_sources(&mut sub);
474 scope.subquery_scopes.push(sub);
475 }
476 _ => {
477 walk_expr_for_subqueries(expr, scope);
479 }
480 }
481}
482
/// Searches the direct child expressions of `expr` for subqueries.
///
/// Each child is handed back to `collect_subqueries_from_expr`, so the two
/// functions recurse through each other until a subquery node or a leaf is
/// reached. Children are visited in the order written in each arm, which
/// determines the order of the resulting `scope.subquery_scopes` entries.
/// Expression kinds not listed here are treated as leaves.
fn walk_expr_for_subqueries(expr: &Expr, scope: &mut Scope) {
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            collect_subqueries_from_expr(left, scope);
            collect_subqueries_from_expr(right, scope);
        }
        Expr::UnaryOp { expr: inner, .. } => {
            collect_subqueries_from_expr(inner, scope);
        }
        // Function calls: every argument, then the optional filter expression.
        Expr::Function { args, filter, .. } => {
            for arg in args {
                collect_subqueries_from_expr(arg, scope);
            }
            if let Some(f) = filter {
                collect_subqueries_from_expr(f, scope);
            }
        }
        Expr::Nested(inner) => {
            collect_subqueries_from_expr(inner, scope);
        }
        // CASE: operand, then each (condition, result) pair, then ELSE.
        Expr::Case {
            operand,
            when_clauses,
            else_clause,
        } => {
            if let Some(op) = operand {
                collect_subqueries_from_expr(op, scope);
            }
            for (cond, result) in when_clauses {
                collect_subqueries_from_expr(cond, scope);
                collect_subqueries_from_expr(result, scope);
            }
            if let Some(el) = else_clause {
                collect_subqueries_from_expr(el, scope);
            }
        }
        Expr::Between {
            expr: inner,
            low,
            high,
            ..
        } => {
            collect_subqueries_from_expr(inner, scope);
            collect_subqueries_from_expr(low, scope);
            collect_subqueries_from_expr(high, scope);
        }
        Expr::InList {
            expr: inner, list, ..
        } => {
            collect_subqueries_from_expr(inner, scope);
            for item in list {
                collect_subqueries_from_expr(item, scope);
            }
        }
        Expr::Cast { expr: inner, .. } | Expr::TryCast { expr: inner, .. } => {
            collect_subqueries_from_expr(inner, scope);
        }
        // Variants that wrap a plain list of expressions.
        Expr::Coalesce(items) | Expr::ArrayLiteral(items) | Expr::Tuple(items) => {
            for item in items {
                collect_subqueries_from_expr(item, scope);
            }
        }
        Expr::If {
            condition,
            true_val,
            false_val,
        } => {
            collect_subqueries_from_expr(condition, scope);
            collect_subqueries_from_expr(true_val, scope);
            if let Some(fv) = false_val {
                collect_subqueries_from_expr(fv, scope);
            }
        }
        Expr::IsNull { expr: inner, .. } | Expr::IsBool { expr: inner, .. } => {
            collect_subqueries_from_expr(inner, scope);
        }
        Expr::Like {
            expr: inner,
            pattern,
            ..
        }
        | Expr::ILike {
            expr: inner,
            pattern,
            ..
        } => {
            collect_subqueries_from_expr(inner, scope);
            collect_subqueries_from_expr(pattern, scope);
        }
        Expr::Alias { expr: inner, .. } | Expr::Collate { expr: inner, .. } => {
            collect_subqueries_from_expr(inner, scope);
        }
        Expr::NullIf {
            expr: inner,
            r#else,
        } => {
            collect_subqueries_from_expr(inner, scope);
            collect_subqueries_from_expr(r#else, scope);
        }
        // `x op ANY (...)` / `x op ALL (...)`: the right side is an ordinary
        // expression here, so a subquery on it is found by the recursion.
        Expr::AnyOp {
            expr: inner, right, ..
        }
        | Expr::AllOp {
            expr: inner, right, ..
        } => {
            collect_subqueries_from_expr(inner, scope);
            collect_subqueries_from_expr(right, scope);
        }
        Expr::ArrayIndex { expr: inner, index } => {
            collect_subqueries_from_expr(inner, scope);
            collect_subqueries_from_expr(index, scope);
        }
        Expr::JsonAccess {
            expr: inner, path, ..
        } => {
            collect_subqueries_from_expr(inner, scope);
            collect_subqueries_from_expr(path, scope);
        }
        Expr::Lambda { body, .. } => {
            collect_subqueries_from_expr(body, scope);
        }
        Expr::Extract { expr: inner, .. } | Expr::Interval { value: inner, .. } => {
            collect_subqueries_from_expr(inner, scope);
        }
        // Grouping constructs that carry a list of expressions.
        Expr::Cube { exprs } | Expr::Rollup { exprs } | Expr::GroupingSets { sets: exprs } => {
            for item in exprs {
                collect_subqueries_from_expr(item, scope);
            }
        }
        // Everything else is treated as a leaf: nothing to descend into.
        _ => {}
    }
}
618
619#[allow(clippy::collapsible_if)]
626fn resolve_selected_sources(scope: &mut Scope) {
627 for col in &scope.columns {
628 if let Some(table) = &col.table {
629 if let Some(source) = scope.sources.get(table) {
630 scope
631 .selected_sources
632 .entry(table.clone())
633 .or_insert_with(|| source.clone());
634 }
635 }
636 }
637}
638
639fn detect_correlation(scope: &mut Scope, outer_source_names: &[String]) {
647 let mut visible: Vec<String> = outer_source_names.to_vec();
649 visible.extend(scope.sources.keys().cloned());
650
651 detect_correlation_in_children(&mut scope.subquery_scopes, &visible);
653 detect_correlation_in_children(&mut scope.derived_table_scopes, &visible);
654 detect_correlation_in_children(&mut scope.union_scopes, &visible);
655 detect_correlation_in_children(&mut scope.cte_scopes, &visible);
656}
657
658#[allow(clippy::collapsible_if)]
659fn detect_correlation_in_children(children: &mut [Scope], outer_names: &[String]) {
660 for child in children.iter_mut() {
661 for col in &child.columns {
664 if let Some(table) = &col.table {
665 if outer_names.contains(table) && !child.sources.contains_key(table) {
666 child.external_columns.push(col.clone());
667 child.is_correlated = true;
668 }
669 }
670 }
671
672 detect_correlation(child, outer_names);
674 }
675}
676
// Tests for scope construction. Each test parses a SQL string with the ANSI
// dialect and inspects the scope tree returned by `build_scope`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;
    use crate::parser::parse;

    // Parses `sql` and builds its root scope; panics if parsing fails.
    fn scope_for(sql: &str) -> Scope {
        let ast = parse(sql, Dialect::Ansi).unwrap();
        build_scope(&ast)
    }

    // A flat SELECT yields an uncorrelated root scope with one source.
    #[test]
    fn test_simple_select() {
        let scope = scope_for("SELECT a, b FROM t WHERE a > 1");
        assert_eq!(scope.scope_type, ScopeType::Root);
        assert!(scope.sources.contains_key("t"));
        // `a` and `b` from the select list, plus `a` again from WHERE.
        assert!(scope.columns.len() >= 2);
        assert!(scope.external_columns.is_empty());
        assert!(!scope.is_correlated);
    }

    // An aliased table is registered under its alias only.
    #[test]
    fn test_aliased_table() {
        let scope = scope_for("SELECT t1.x FROM my_table t1");
        assert!(scope.sources.contains_key("t1"));
        assert!(!scope.sources.contains_key("my_table"));
    }

    // Both sides of a JOIN are sources, and ON columns are collected.
    #[test]
    fn test_join_sources() {
        let scope = scope_for("SELECT a.id, b.val FROM alpha a JOIN beta b ON a.id = b.id");
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
        let on_cols: Vec<_> = scope.columns.iter().filter(|c| c.name == "id").collect();
        // `a.id` appears in the select list and, with `b.id`, in the ON clause.
        assert!(on_cols.len() >= 2);
    }

    // A subquery in FROM becomes a derived-table child scope and a source.
    #[test]
    fn test_derived_table() {
        let scope = scope_for("SELECT sub.x FROM (SELECT a AS x FROM t) sub");
        assert!(scope.sources.contains_key("sub"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let dt = &scope.derived_table_scopes[0];
        assert_eq!(dt.scope_type, ScopeType::DerivedTable);
        assert!(dt.sources.contains_key("t"));
    }

    // A CTE becomes a child scope and a source named after the CTE.
    #[test]
    fn test_cte_scope() {
        let scope = scope_for("WITH cte AS (SELECT id FROM t) SELECT cte.id FROM cte");
        assert!(scope.sources.contains_key("cte"));
        assert_eq!(scope.cte_scopes.len(), 1);

        let cte = &scope.cte_scopes[0];
        assert_eq!(cte.scope_type, ScopeType::Cte);
        assert!(cte.sources.contains_key("t"));
    }

    // Every CTE in a WITH list gets its own child scope and source entry.
    #[test]
    fn test_multiple_ctes() {
        let scope = scope_for(
            "WITH a AS (SELECT 1 AS x), b AS (SELECT 2 AS y) \
             SELECT a.x, b.y FROM a, b",
        );
        assert_eq!(scope.cte_scopes.len(), 2);
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
    }

    // A set operation yields two union scopes: left operand first, then right.
    #[test]
    fn test_union_scopes() {
        let scope = scope_for("SELECT a FROM t1 UNION ALL SELECT b FROM t2");
        assert_eq!(scope.union_scopes.len(), 2);

        let left = &scope.union_scopes[0];
        assert_eq!(left.scope_type, ScopeType::Union);
        assert!(left.sources.contains_key("t1"));

        let right = &scope.union_scopes[1];
        assert!(right.sources.contains_key("t2"));
    }

    // A scalar subquery in the select list becomes a subquery child scope.
    #[test]
    fn test_scalar_subquery() {
        let scope = scope_for("SELECT (SELECT MAX(x) FROM t2) AS mx FROM t1");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert_eq!(sub.scope_type, ScopeType::Subquery);
        assert!(sub.sources.contains_key("t2"));
    }

    // An EXISTS subquery referencing the outer table is correlated.
    #[test]
    fn test_exists_subquery() {
        let scope =
            scope_for("SELECT a FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.id = t1.id)");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert!(sub.sources.contains_key("t2"));
        assert!(sub.is_correlated);
        // `t1.id` is the only column whose qualifier comes from the outer scope.
        assert!(!sub.external_columns.is_empty());
        let ext = &sub.external_columns[0];
        assert_eq!(ext.table.as_deref(), Some("t1"));
        assert_eq!(ext.name, "id");
    }

    // An IN subquery with no outer references is not correlated.
    #[test]
    fn test_in_subquery() {
        let scope = scope_for("SELECT a FROM t1 WHERE a IN (SELECT b FROM t2)");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert!(sub.sources.contains_key("t2"));
        assert!(!sub.is_correlated);
    }

    // A scalar subquery in a comparison that references the outer table is
    // correlated.
    #[test]
    fn test_correlated_subquery() {
        let scope =
            scope_for("SELECT a FROM t1 WHERE a = (SELECT MAX(b) FROM t2 WHERE t2.fk = t1.id)");
        assert_eq!(scope.subquery_scopes.len(), 1);
        let sub = &scope.subquery_scopes[0];
        assert!(sub.is_correlated);
        assert!(
            sub.external_columns
                .iter()
                .any(|c| c.table.as_deref() == Some("t1"))
        );
    }

    // Subqueries nest: the inner scalar subquery hangs off the IN subquery,
    // not off the root.
    #[test]
    fn test_nested_subqueries() {
        let scope = scope_for(
            "SELECT a FROM t1 WHERE a IN (SELECT b FROM t2 WHERE b > (SELECT MIN(c) FROM t3))",
        );
        assert_eq!(scope.subquery_scopes.len(), 1);

        let in_sub = &scope.subquery_scopes[0];
        assert!(in_sub.sources.contains_key("t2"));
        assert_eq!(in_sub.subquery_scopes.len(), 1);
        let inner = &in_sub.subquery_scopes[0];
        assert!(inner.sources.contains_key("t3"));
    }

    // A source referenced by a qualified column is a selected source.
    #[test]
    fn test_selected_sources() {
        let scope = scope_for("SELECT a.x FROM alpha a JOIN beta b ON a.id = b.id");
        assert!(scope.selected_sources.contains_key("a"));
    }

    // `find_all_in_scope` filters the scope's own columns with a predicate.
    #[test]
    fn test_find_all_in_scope() {
        let scope = scope_for("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
        let t_cols = find_all_in_scope(&scope, &|c| c.table.as_deref() == Some("t"));
        // `t.a`, `t.b` and `t.id`.
        assert!(t_cols.len() >= 3);
    }

    // `walk` visits the root and every nested scope.
    #[test]
    fn test_scope_walk() {
        let scope = scope_for(
            "WITH cte AS (SELECT 1 AS a) \
             SELECT * FROM cte WHERE EXISTS (SELECT 1 FROM t)",
        );
        let mut count = 0;
        scope.walk(&mut |_| count += 1);
        // Root, the CTE scope and the EXISTS subquery scope.
        assert!(count >= 3);
    }

    // `t.*` is recorded as column "*" qualified by `t` and selects source `t`.
    #[test]
    fn test_qualified_wildcard() {
        let scope = scope_for("SELECT t.* FROM t");
        assert!(
            scope
                .columns
                .iter()
                .any(|c| c.table.as_deref() == Some("t") && c.name == "*")
        );
        assert!(scope.selected_sources.contains_key("t"));
    }
}