Skip to main content

sqrust_rules/ambiguous/
self_join.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{JoinConstraint, JoinOperator, Query, SetExpr, Statement, TableFactor};
3use std::collections::HashMap;
4
5pub struct SelfJoin;
6
7impl Rule for SelfJoin {
8    fn name(&self) -> &'static str {
9        "Ambiguous/SelfJoin"
10    }
11
12    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
13        if !ctx.parse_errors.is_empty() {
14            return Vec::new();
15        }
16
17        let mut diags = Vec::new();
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, &ctx.source, &mut diags);
21            }
22        }
23        diags
24    }
25}
26
27fn check_query(query: &Query, source: &str, diags: &mut Vec<Diagnostic>) {
28    if let Some(with) = &query.with {
29        for cte in &with.cte_tables {
30            check_query(&cte.query, source, diags);
31        }
32    }
33    check_set_expr(&query.body, source, diags);
34}
35
36fn check_set_expr(expr: &SetExpr, source: &str, diags: &mut Vec<Diagnostic>) {
37    match expr {
38        SetExpr::Select(sel) => {
39            for twj in &sel.from {
40                // Collect all table references in this FROM + JOIN chain.
41                let mut refs: Vec<(String, Option<String>)> = Vec::new();
42
43                collect_table_ref(&twj.relation, &mut refs, source, diags);
44                for join in &twj.joins {
45                    collect_table_ref(&join.relation, &mut refs, source, diags);
46
47                    // Also collect ON-clause subqueries.
48                    if let JoinOperator::Inner(JoinConstraint::On(_))
49                    | JoinOperator::LeftOuter(JoinConstraint::On(_))
50                    | JoinOperator::RightOuter(JoinConstraint::On(_))
51                    | JoinOperator::FullOuter(JoinConstraint::On(_)) = &join.join_operator
52                    {
53                        // ON expressions don't contain subqueries we need to recurse into here.
54                    }
55                }
56
57                // Detect self-joins: same table name used twice without distinct aliases.
58                detect_self_joins(&refs, source, diags);
59            }
60
61            // Recurse into subqueries in FROM (Derived table factors).
62            for twj in &sel.from {
63                recurse_subqueries_in_factor(&twj.relation, source, diags);
64                for join in &twj.joins {
65                    recurse_subqueries_in_factor(&join.relation, source, diags);
66                }
67            }
68        }
69        SetExpr::Query(inner) => check_query(inner, source, diags),
70        SetExpr::SetOperation { left, right, .. } => {
71            check_set_expr(left, source, diags);
72            check_set_expr(right, source, diags);
73        }
74        _ => {}
75    }
76}
77
78/// Appends a (table_name_lowercase, alias_name_or_none) entry for each
79/// `TableFactor::Table` found in `factor`. Derived factors (subqueries) are
80/// NOT included here — they are handled separately via `recurse_subqueries_in_factor`.
81fn collect_table_ref(
82    factor: &TableFactor,
83    refs: &mut Vec<(String, Option<String>)>,
84    _source: &str,
85    _diags: &mut Vec<Diagnostic>,
86) {
87    if let TableFactor::Table { name, alias, .. } = factor {
88        let table_name = name
89            .0
90            .last()
91            .map(|i| i.value.to_lowercase())
92            .unwrap_or_default();
93        let alias_name = alias.as_ref().map(|a| a.name.value.to_lowercase());
94        refs.push((table_name, alias_name));
95    }
96}
97
98/// Recurse into derived (subquery) table factors.
99fn recurse_subqueries_in_factor(
100    factor: &TableFactor,
101    source: &str,
102    diags: &mut Vec<Diagnostic>,
103) {
104    if let TableFactor::Derived { subquery, .. } = factor {
105        check_query(subquery, source, diags);
106    }
107}
108
109/// Given a list of (table_name, alias_option) pairs from a single FROM clause,
110/// find cases where the same table appears twice and at least one occurrence
111/// has no alias, or both have the same alias.
112fn detect_self_joins(
113    refs: &[(String, Option<String>)],
114    source: &str,
115    diags: &mut Vec<Diagnostic>,
116) {
117    // Group by table name.
118    // For each table, collect all aliases (None = no alias).
119    let mut by_name: HashMap<&str, Vec<Option<&str>>> = HashMap::new();
120    for (name, alias) in refs {
121        by_name
122            .entry(name.as_str())
123            .or_default()
124            .push(alias.as_deref());
125    }
126
127    for (table_name, aliases) in &by_name {
128        if aliases.len() < 2 {
129            continue;
130        }
131
132        // Self-join is ambiguous when at least one occurrence has no alias
133        // OR when any two occurrences share the same alias.
134        let is_ambiguous = aliases.iter().any(|a| a.is_none())
135            || {
136                // Check for duplicate aliases.
137                let named: Vec<&str> = aliases.iter().filter_map(|a| *a).collect();
138                has_duplicate(&named)
139            };
140
141        if is_ambiguous {
142            // Find the second occurrence of the table name in the source text.
143            let pos = find_second_occurrence(source, table_name);
144            let (line, col) = offset_to_line_col(source, pos);
145            diags.push(Diagnostic {
146                rule: "Ambiguous/SelfJoin",
147                message: format!(
148                    "Table '{}' is joined to itself without distinct aliases",
149                    table_name
150                ),
151                line,
152                col,
153            });
154        }
155    }
156}
157
/// Reports whether some string occurs at least twice in `names`.
///
/// Comparison is exact (case-sensitive). Empty and single-element slices
/// never contain a duplicate.
fn has_duplicate(names: &[&str]) -> bool {
    // Each element only needs to be compared with the ones after it.
    names
        .iter()
        .enumerate()
        .any(|(idx, name)| names[idx + 1..].contains(name))
}
169
170/// Finds the byte offset of the second whole-word, case-insensitive occurrence
171/// of `name` in `source`. Falls back to 0 (will resolve to line 1, col 1).
172fn find_second_occurrence(source: &str, name: &str) -> usize {
173    find_nth_occurrence(source, name, 1)
174}
175
/// Finds the byte offset of the `nth` (0-indexed) whole-word occurrence of
/// `name` in `source`.
///
/// * Matching is ASCII case-insensitive; non-ASCII bytes must match exactly.
/// * "Whole word" means the bytes immediately before and after the match are
///   neither ASCII alphanumerics nor `_` (or the match touches the start/end
///   of `source`).
///
/// Returns 0 when fewer than `nth + 1` occurrences exist. Note that 0 is also
/// the offset of a genuine match at the very start of `source`.
///
/// An empty `name` never matches and always yields 0. Without this guard an
/// empty needle would "match" at arbitrary positions, including the middle of
/// a multi-byte UTF-8 character, producing an offset that cannot be used to
/// slice `source`.
fn find_nth_occurrence(source: &str, name: &str, nth: usize) -> usize {
    let haystack = source.as_bytes();
    let needle = name.as_bytes();

    // Also protects `windows`, which panics on a window size of zero.
    if needle.is_empty() {
        return 0;
    }

    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    haystack
        .windows(needle.len())
        .enumerate()
        .filter(|&(start, window)| {
            let end = start + needle.len();
            window.eq_ignore_ascii_case(needle)
                && (start == 0 || !is_word_byte(haystack[start - 1]))
                && (end == haystack.len() || !is_word_byte(haystack[end]))
        })
        .map(|(start, _)| start)
        .nth(nth)
        .unwrap_or(0)
}
220
/// Converts a byte offset in `source` to a 1-indexed `(line, col)` pair.
///
/// `line` is one more than the number of `\n` characters before the offset.
/// `col` is the 1-indexed distance in bytes from the start of that line.
///
/// The offset is clamped so the function never panics: values past the end
/// are treated as `source.len()`, and an offset that falls inside a
/// multi-byte UTF-8 character is moved back to the start of that character.
/// Offsets that are already valid give the same result as plain slicing.
fn offset_to_line_col(source: &str, offset: usize) -> (usize, usize) {
    let mut end = offset.min(source.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !source.is_char_boundary(end) {
        end -= 1;
    }

    let before = &source[..end];
    let line = before.bytes().filter(|&b| b == b'\n').count() + 1;
    let col = match before.rfind('\n') {
        // Bytes after the last newline, plus one for 1-indexing.
        Some(newline) => end - newline,
        None => end + 1,
    };
    (line, col)
}