1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{
3 Expr, FunctionArgExpr, GroupByExpr, Query, Select, SelectItem, SetExpr,
4 Statement, TableFactor, With,
5};
6
7pub struct UnqualifiedColumnInJoin;
8
9impl Rule for UnqualifiedColumnInJoin {
10 fn name(&self) -> &'static str {
11 "Structure/UnqualifiedColumnInJoin"
12 }
13
14 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
15 if !ctx.parse_errors.is_empty() {
16 return Vec::new();
17 }
18 let mut diags = Vec::new();
19 for stmt in &ctx.statements {
20 if let Statement::Query(q) = stmt {
21 check_query(q, &ctx.source, self.name(), &mut diags);
22 }
23 }
24 diags
25 }
26}
27
28fn check_query(q: &Query, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
29 if let Some(With { cte_tables, .. }) = &q.with {
30 for cte in cte_tables {
31 check_query(&cte.query, src, rule, diags);
32 }
33 }
34 check_set_expr(&q.body, src, rule, diags);
35}
36
37fn check_set_expr(body: &SetExpr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
38 match body {
39 SetExpr::Select(s) => check_select(s, src, rule, diags),
40 SetExpr::SetOperation { left, right, .. } => {
41 check_set_expr(left, src, rule, diags);
42 check_set_expr(right, src, rule, diags);
43 }
44 SetExpr::Query(q) => check_query(q, src, rule, diags),
45 _ => {}
46 }
47}
48
49fn check_select(sel: &Select, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
50 for twj in &sel.from {
52 recurse_table_factor(&twj.relation, src, rule, diags);
53 for join in &twj.joins {
54 recurse_table_factor(&join.relation, src, rule, diags);
55 }
56 }
57
58 let has_joins = sel.from.iter().any(|twj| !twj.joins.is_empty());
60 if !has_joins {
61 return;
62 }
63
64 for item in &sel.projection {
66 match item {
67 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
68 find_unqualified(e, src, rule, diags);
69 }
70 SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(_, _) => {}
71 }
72 }
73
74 if let Some(w) = &sel.selection {
76 find_unqualified(w, src, rule, diags);
77 }
78
79 if let Some(h) = &sel.having {
81 find_unqualified(h, src, rule, diags);
82 }
83
84 if let GroupByExpr::Expressions(exprs, _) = &sel.group_by {
86 for g in exprs {
87 find_unqualified(g, src, rule, diags);
88 }
89 }
90}
91
92fn recurse_table_factor(tf: &TableFactor, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
93 if let TableFactor::Derived { subquery, .. } = tf {
94 check_query(subquery, src, rule, diags);
95 }
96}
97
98fn find_unqualified(expr: &Expr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
99 match expr {
100 Expr::Identifier(i) => {
101 if let Some(off) = find_word_in_source(src, &i.value, 0) {
103 let (line, col) = offset_to_line_col(src, off);
104 diags.push(Diagnostic {
105 rule,
106 message: format!(
107 "Column '{}' is not qualified with a table name or alias; in a JOIN query, all columns should be table-qualified",
108 i.value
109 ),
110 line,
111 col,
112 });
113 }
114 }
115 Expr::CompoundIdentifier(_) => {} Expr::BinaryOp { left, right, .. } => {
117 find_unqualified(left, src, rule, diags);
118 find_unqualified(right, src, rule, diags);
119 }
120 Expr::UnaryOp { expr, .. } | Expr::Nested(expr) => {
121 find_unqualified(expr, src, rule, diags);
122 }
123 Expr::Function(f) => {
124 if let sqlparser::ast::FunctionArguments::List(arg_list) = &f.args {
125 for arg in &arg_list.args {
126 if let sqlparser::ast::FunctionArg::Unnamed(arg_expr) = arg {
127 if let FunctionArgExpr::Expr(e) = arg_expr {
128 find_unqualified(e, src, rule, diags);
129 }
130 }
131 }
132 }
133 }
134 Expr::IsNull(e) | Expr::IsNotNull(e) => find_unqualified(e, src, rule, diags),
135 Expr::Between { expr, low, high, .. } => {
136 find_unqualified(expr, src, rule, diags);
137 find_unqualified(low, src, rule, diags);
138 find_unqualified(high, src, rule, diags);
139 }
140 Expr::InList { expr, list, .. } => {
141 find_unqualified(expr, src, rule, diags);
142 for e in list {
143 find_unqualified(e, src, rule, diags);
144 }
145 }
146 Expr::Case { operand, conditions, results, else_result } => {
147 if let Some(e) = operand {
148 find_unqualified(e, src, rule, diags);
149 }
150 for (c, r) in conditions.iter().zip(results.iter()) {
151 find_unqualified(c, src, rule, diags);
152 find_unqualified(r, src, rule, diags);
153 }
154 if let Some(e) = else_result {
155 find_unqualified(e, src, rule, diags);
156 }
157 }
158 _ => {}
159 }
160}
161
162fn find_word_in_source(src: &str, word: &str, start: usize) -> Option<usize> {
163 let bytes = src.as_bytes();
164 let wbytes = word.as_bytes();
165 let wlen = wbytes.len();
166 if wlen == 0 {
167 return None;
168 }
169 let mut i = start;
170 while i + wlen <= bytes.len() {
171 if bytes[i..i + wlen].eq_ignore_ascii_case(wbytes) {
172 let before_ok = i == 0 || (!is_wc(bytes[i - 1]) && bytes[i - 1] != b'.');
173 let after_ok = i + wlen >= bytes.len() || !is_wc(bytes[i + wlen]);
174 if before_ok && after_ok {
175 return Some(i);
176 }
177 }
178 i += 1;
179 }
180 None
181}
182
/// Returns true for the bytes treated as part of a word when looking for word
/// boundaries: ASCII letters, ASCII digits and `_`.
fn is_wc(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_')
}
186
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// The column is measured in bytes from the start of the line. The prefix used
/// for counting is clamped to the source length, but the column is computed
/// from the unclamped `offset`.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset.min(source.len())];
    let line = 1 + prefix.bytes().filter(|&b| b == b'\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |nl| nl + 1);
    (line, offset - line_start + 1)
}