1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Query, Select, SetExpr, Statement, TableFactor, TableWithJoins, With};
3use std::collections::{HashMap, HashSet};
4
5pub struct DuplicateJoin;
6
7impl Rule for DuplicateJoin {
8 fn name(&self) -> &'static str {
9 "Lint/DuplicateJoin"
10 }
11
12 fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13 if !ctx.parse_errors.is_empty() {
14 return Vec::new();
15 }
16 let mut diags = Vec::new();
17 for stmt in &ctx.statements {
18 check_stmt(stmt, &ctx.source, "Lint/DuplicateJoin", &mut diags);
19 }
20 diags
21 }
22}
23
24fn check_stmt(stmt: &Statement, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
25 if let Statement::Query(q) = stmt {
26 check_query(q, src, rule, diags);
27 }
28}
29
30fn check_query(q: &Query, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
31 if let Some(With { cte_tables, .. }) = &q.with {
32 for cte in cte_tables {
33 check_query(&cte.query, src, rule, diags);
34 }
35 }
36 check_set_expr(&q.body, src, rule, diags);
37}
38
39fn check_set_expr(body: &SetExpr, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
40 match body {
41 SetExpr::Select(s) => check_select(s, src, rule, diags),
42 SetExpr::SetOperation { left, right, .. } => {
43 check_set_expr(left, src, rule, diags);
44 check_set_expr(right, src, rule, diags);
45 }
46 SetExpr::Query(q) => check_query(q, src, rule, diags),
47 _ => {}
48 }
49}
50
51fn check_select(sel: &Select, src: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
52 for twj in &sel.from {
53 check_table_with_joins(twj, src, rule, diags);
54 }
55}
56
57fn check_table_with_joins(
58 twj: &TableWithJoins,
59 src: &str,
60 rule: &'static str,
61 diags: &mut Vec<Diagnostic>,
62) {
63 let mut seen: HashMap<String, usize> = HashMap::new();
65 let mut already_flagged: HashSet<String> = HashSet::new();
66 let mut last_off: usize = 0;
67
68 if let Some((name, off)) = table_factor_name(&twj.relation, src, last_off) {
70 last_off = off + 1;
71 seen.insert(name, off);
72 }
73
74 check_factor_subqueries(&twj.relation, src, rule, diags);
76
77 for join in &twj.joins {
79 check_factor_subqueries(&join.relation, src, rule, diags);
80 if let Some((name, off)) = table_factor_name(&join.relation, src, last_off) {
81 last_off = off + 1;
82 if seen.contains_key(&name) && !already_flagged.contains(&name) {
83 let (line, col) = offset_to_line_col(src, off);
84 diags.push(Diagnostic {
85 rule,
86 message: format!(
87 "Table '{}' is joined more than once in the same FROM clause",
88 name
89 ),
90 line,
91 col,
92 });
93 already_flagged.insert(name.clone());
94 } else if !seen.contains_key(&name) {
95 seen.insert(name, off);
96 }
97 }
98 }
99}
100
101fn table_factor_name(tf: &TableFactor, src: &str, start: usize) -> Option<(String, usize)> {
102 match tf {
103 TableFactor::Table { name, .. } => {
104 let full_name = name
105 .0
106 .iter()
107 .map(|i| i.value.to_lowercase())
108 .collect::<Vec<_>>()
109 .join(".");
110 let last = name.0.last()?.value.clone();
111 let off = find_word_in_source(src, &last, start)?;
112 Some((full_name, off))
113 }
114 _ => None,
115 }
116}
117
118fn check_factor_subqueries(
119 tf: &TableFactor,
120 src: &str,
121 rule: &'static str,
122 diags: &mut Vec<Diagnostic>,
123) {
124 if let TableFactor::Derived { subquery, .. } = tf {
125 check_query(subquery, src, rule, diags);
126 }
127}
128
129fn find_word_in_source(src: &str, word: &str, start: usize) -> Option<usize> {
130 let bytes = src.as_bytes();
131 let wbytes = word.as_bytes();
132 let wlen = wbytes.len();
133 if wlen == 0 {
134 return None;
135 }
136 let mut i = start;
137 while i + wlen <= bytes.len() {
138 if bytes[i..i + wlen].eq_ignore_ascii_case(wbytes) {
139 let before_ok = i == 0 || !is_wc(bytes[i - 1]);
140 let after_ok = i + wlen >= bytes.len() || !is_wc(bytes[i + wlen]);
141 if before_ok && after_ok {
142 return Some(i);
143 }
144 }
145 i += 1;
146 }
147 None
148}
149
/// Returns `true` when `b` is a word character: an ASCII letter, an ASCII
/// digit, or an underscore.
fn is_wc(b: u8) -> bool {
    matches!(b, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
153
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// Lines are separated by `'\n'`. The column is counted in bytes from the
/// start of the line, so it is also 1-based.
///
/// The offset is normalised before use so the function never panics and the
/// line and column always describe the same position:
/// * an offset past the end of `source` is clamped to `source.len()`;
/// * an offset that falls inside a multi-byte character is moved back to the
///   start of that character.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }

    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    // Byte index of the first character on the current line.
    let line_start = before.rfind('\n').map_or(0, |newline| newline + 1);
    (line, end - line_start + 1)
}