sqrust_rules/ambiguous/
self_join.rs1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{JoinConstraint, JoinOperator, Query, SetExpr, Statement, TableFactor};
3use std::collections::HashMap;
4
5pub struct SelfJoin;
6
7impl Rule for SelfJoin {
8 fn name(&self) -> &'static str {
9 "Ambiguous/SelfJoin"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16
17 let mut diags = Vec::new();
18 for stmt in &ctx.statements {
19 if let Statement::Query(query) = stmt {
20 check_query(query, &ctx.source, &mut diags);
21 }
22 }
23 diags
24 }
25}
26
27fn check_query(query: &Query, source: &str, diags: &mut Vec<Diagnostic>) {
28 if let Some(with) = &query.with {
29 for cte in &with.cte_tables {
30 check_query(&cte.query, source, diags);
31 }
32 }
33 check_set_expr(&query.body, source, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, source: &str, diags: &mut Vec<Diagnostic>) {
37 match expr {
38 SetExpr::Select(sel) => {
39 for twj in &sel.from {
40 let mut refs: Vec<(String, Option<String>)> = Vec::new();
42
43 collect_table_ref(&twj.relation, &mut refs, source, diags);
44 for join in &twj.joins {
45 collect_table_ref(&join.relation, &mut refs, source, diags);
46
47 if let JoinOperator::Inner(JoinConstraint::On(_))
49 | JoinOperator::LeftOuter(JoinConstraint::On(_))
50 | JoinOperator::RightOuter(JoinConstraint::On(_))
51 | JoinOperator::FullOuter(JoinConstraint::On(_)) = &join.join_operator
52 {
53 }
55 }
56
57 detect_self_joins(&refs, source, diags);
59 }
60
61 for twj in &sel.from {
63 recurse_subqueries_in_factor(&twj.relation, source, diags);
64 for join in &twj.joins {
65 recurse_subqueries_in_factor(&join.relation, source, diags);
66 }
67 }
68 }
69 SetExpr::Query(inner) => check_query(inner, source, diags),
70 SetExpr::SetOperation { left, right, .. } => {
71 check_set_expr(left, source, diags);
72 check_set_expr(right, source, diags);
73 }
74 _ => {}
75 }
76}
77
78fn collect_table_ref(
82 factor: &TableFactor,
83 refs: &mut Vec<(String, Option<String>)>,
84 _source: &str,
85 _diags: &mut Vec<Diagnostic>,
86) {
87 if let TableFactor::Table { name, alias, .. } = factor {
88 let table_name = name
89 .0
90 .last()
91 .map(|i| i.value.to_lowercase())
92 .unwrap_or_default();
93 let alias_name = alias.as_ref().map(|a| a.name.value.to_lowercase());
94 refs.push((table_name, alias_name));
95 }
96}
97
98fn recurse_subqueries_in_factor(
100 factor: &TableFactor,
101 source: &str,
102 diags: &mut Vec<Diagnostic>,
103) {
104 if let TableFactor::Derived { subquery, .. } = factor {
105 check_query(subquery, source, diags);
106 }
107}
108
109fn detect_self_joins(
113 refs: &[(String, Option<String>)],
114 source: &str,
115 diags: &mut Vec<Diagnostic>,
116) {
117 let mut by_name: HashMap<&str, Vec<Option<&str>>> = HashMap::new();
120 for (name, alias) in refs {
121 by_name
122 .entry(name.as_str())
123 .or_default()
124 .push(alias.as_deref());
125 }
126
127 for (table_name, aliases) in &by_name {
128 if aliases.len() < 2 {
129 continue;
130 }
131
132 let is_ambiguous = aliases.iter().any(|a| a.is_none())
135 || {
136 let named: Vec<&str> = aliases.iter().filter_map(|a| *a).collect();
138 has_duplicate(&named)
139 };
140
141 if is_ambiguous {
142 let pos = find_second_occurrence(source, table_name);
144 let (line, col) = offset_to_line_col(source, pos);
145 diags.push(Diagnostic {
146 rule: "Ambiguous/SelfJoin",
147 message: format!(
148 "Table '{}' is joined to itself without distinct aliases",
149 table_name
150 ),
151 line,
152 col,
153 });
154 }
155 }
156}
157
/// Returns `true` if any string occurs more than once in `names`.
///
/// Pairwise comparison is used; alias lists here are expected to be tiny.
fn has_duplicate(names: &[&str]) -> bool {
    names
        .iter()
        .enumerate()
        .any(|(idx, name)| names[idx + 1..].contains(name))
}
169
170fn find_second_occurrence(source: &str, name: &str) -> usize {
173 find_nth_occurrence(source, name, 1)
174}
175
/// Returns the byte offset of the `nth` (0-based) whole-word occurrence of
/// `name` in `source`, or 0 if there is no such occurrence.
///
/// Matching is ASCII case-insensitive: non-ASCII bytes must match exactly.
/// A match counts as a whole word only when the bytes immediately before and
/// after it are neither ASCII alphanumerics nor `_`.
///
/// An empty `name` never matches and yields 0. Without this guard it would
/// "match" at every byte offset, including offsets inside multi-byte UTF-8
/// characters, and callers that slice `source` at the result would panic.
/// For a non-empty `name`, every returned offset is a char boundary, because
/// the first matched byte equals `name`'s first byte (ASCII or a UTF-8 lead byte).
fn find_nth_occurrence(source: &str, name: &str, nth: usize) -> usize {
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    let haystack = source.as_bytes();
    let needle = name.as_bytes();
    if needle.is_empty() || needle.len() > haystack.len() {
        return 0;
    }

    let mut seen = 0usize;
    for start in 0..=haystack.len() - needle.len() {
        let end = start + needle.len();
        let before_ok = start == 0 || !is_word_byte(haystack[start - 1]);
        let after_ok = end == haystack.len() || !is_word_byte(haystack[end]);
        if before_ok && after_ok && haystack[start..end].eq_ignore_ascii_case(needle) {
            if seen == nth {
                return start;
            }
            seen += 1;
        }
    }
    0
}
220
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// Lines are delimited by `\n`. The column is counted in bytes from the start
/// of the line. Panics if `offset` is past the end of `source` or does not lie
/// on a char boundary.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let line_number = 1 + prefix.matches('\n').count();
    let line_start = prefix.rfind('\n').map_or(0, |newline| newline + 1);
    (line_number, offset - line_start + 1)
}