Skip to main content

sqrust_rules/convention/
no_using_clause.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Join, JoinConstraint, JoinOperator, Query, Select, SetExpr, Statement,
3    TableFactor, TableWithJoins};
4
/// Lint rule registered as `Convention/NoUsingClause`.
///
/// Its `Rule::check` implementation walks every query statement in a file and
/// emits one diagnostic for each join whose constraint is `USING (...)`,
/// recommending an explicit `ON` condition instead.
pub struct NoUsingClause;
6
/// Translates a byte offset within `source` into a 1-based `(line, column)` pair.
///
/// The line is one more than the number of `\n` bytes that precede `offset`.
/// The column is one more than the number of bytes between the most recent
/// `\n` (or the start of `source` when there is none) and `offset`, so it is
/// measured in bytes rather than characters.
///
/// # Panics
///
/// Panics if `offset` lies past the end of `source` or does not fall on a
/// UTF-8 character boundary, because the prefix is obtained by string slicing.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    let newlines_seen = prefix.bytes().filter(|&b| b == b'\n').count();
    let bytes_into_line = match prefix.rfind('\n') {
        // Distance from the byte right after the newline up to `offset`.
        Some(newline_at) => offset - newline_at - 1,
        // Still on the first line: the offset itself is the distance.
        None => offset,
    };
    (newlines_seen + 1, bytes_into_line + 1)
}
14
/// Reports whether `ch` can be part of a SQL word: an ASCII letter, an ASCII
/// digit, or an underscore (`[a-zA-Z0-9_]`).
#[inline]
fn is_word_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
20
/// Produces a per-byte mask over `bytes` that is `true` for every byte that
/// belongs to a comment, a string literal, or a quoted identifier.
///
/// Recognised regions (delimiters included in the mask):
/// - line comments: `--` up to, but not including, the next `\n`;
/// - block comments: `/*` through the first following `*/`;
/// - single-quoted strings, where `''` inside the string is an escaped quote;
/// - double-quoted identifiers: `"` through the next `"`;
/// - backtick identifiers: `` ` `` through the next `` ` ``.
///
/// A region whose terminator is missing extends to the end of the input.
/// The returned vector always has the same length as `bytes`.
fn build_skip(bytes: &[u8]) -> Vec<bool> {
    let total = bytes.len();
    let mut mask = vec![false; total];
    let mut pos = 0;

    while pos < total {
        let lookahead = bytes.get(pos + 1).copied();

        // For a region opener at `pos`, compute the exclusive end of the
        // region; `None` means the byte at `pos` is ordinary code.
        let region_end = match (bytes[pos], lookahead) {
            (b'-', Some(b'-')) => {
                let body = pos + 2;
                let end = bytes[body..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(total, |rel| body + rel);
                Some(end)
            }
            (b'/', Some(b'*')) => {
                // The terminator is searched for after the opener, so `/*/`
                // does not close itself.
                let body = pos + 2;
                let end = bytes[body..]
                    .windows(2)
                    .position(|pair| pair == b"*/")
                    .map_or(total, |rel| body + rel + 2);
                Some(end)
            }
            (b'\'', _) => {
                let mut cursor = pos + 1;
                while cursor < total {
                    if bytes[cursor] != b'\'' {
                        cursor += 1;
                    } else if bytes.get(cursor + 1) == Some(&b'\'') {
                        // Doubled quote: escaped, the string continues.
                        cursor += 2;
                    } else {
                        // Lone quote: closes the string.
                        cursor += 1;
                        break;
                    }
                }
                Some(cursor)
            }
            (quote @ b'"', _) | (quote @ b'`', _) => {
                let body = pos + 1;
                let end = bytes[body..]
                    .iter()
                    .position(|&b| b == quote)
                    .map_or(total, |rel| body + rel + 1);
                Some(end)
            }
            _ => None,
        };

        match region_end {
            Some(end) => {
                mask[pos..end].fill(true);
                pos = end;
            }
            None => pos += 1,
        }
    }

    mask
}
115
116/// Finds the Nth occurrence (0-indexed) of the `USING` keyword (case-insensitive,
117/// word-boundary) in `source` outside strings/comments.
118///
119/// Returns the byte offset of the `U` in `USING`, or `None` if not found.
120fn find_nth_using(source: &str, skip: &[bool], occurrence: usize) -> Option<usize> {
121    let bytes = source.as_bytes();
122    let len = bytes.len();
123    let pattern = b"USING";
124    let pat_len = pattern.len();
125    let mut count = 0;
126    let mut i = 0;
127
128    while i + pat_len <= len {
129        if skip[i] {
130            i += 1;
131            continue;
132        }
133
134        // Case-insensitive match
135        let matches = bytes[i..i + pat_len]
136            .iter()
137            .zip(pattern.iter())
138            .all(|(&a, &b)| a.eq_ignore_ascii_case(&b));
139
140        if matches {
141            let all_code = (i..i + pat_len).all(|k| !skip[k]);
142            if all_code {
143                let boundary_before = i == 0 || !is_word_char(bytes[i - 1]);
144                let end = i + pat_len;
145                let boundary_after = end >= len || !is_word_char(bytes[end]);
146                if boundary_before && boundary_after {
147                    if count == occurrence {
148                        return Some(i);
149                    }
150                    count += 1;
151                    i += pat_len;
152                    continue;
153                }
154            }
155        }
156
157        i += 1;
158    }
159
160    None
161}
162
163/// Returns `true` if the join has a `USING(...)` constraint.
164fn join_has_using(join: &Join) -> bool {
165    let constraint = match &join.join_operator {
166        JoinOperator::Inner(c) => Some(c),
167        JoinOperator::LeftOuter(c) => Some(c),
168        JoinOperator::RightOuter(c) => Some(c),
169        JoinOperator::FullOuter(c) => Some(c),
170        JoinOperator::Semi(c) => Some(c),
171        JoinOperator::LeftSemi(c) => Some(c),
172        JoinOperator::RightSemi(c) => Some(c),
173        JoinOperator::Anti(c) => Some(c),
174        JoinOperator::LeftAnti(c) => Some(c),
175        JoinOperator::RightAnti(c) => Some(c),
176        JoinOperator::CrossJoin
177        | JoinOperator::CrossApply
178        | JoinOperator::OuterApply
179        | JoinOperator::AsOf { .. } => None,
180    };
181    matches!(constraint, Some(JoinConstraint::Using(_)))
182}
183
184/// Recurses into a `TableFactor` to find derived-table subqueries.
185fn collect_from_table_factor(
186    factor: &TableFactor,
187    source: &str,
188    skip: &[bool],
189    using_count: &mut usize,
190    diags: &mut Vec<Diagnostic>,
191) {
192    if let TableFactor::Derived { subquery, .. } = factor {
193        collect_from_query(subquery, source, skip, using_count, diags);
194    }
195}
196
197/// Collects USING violations from a list of `TableWithJoins` items.
198/// `using_count` tracks how many USING occurrences in the source we've consumed,
199/// so each violation maps to the correct text position.
200///
201/// For each `TableWithJoins` we first recurse into the relation (which may be
202/// a derived subquery), then into each join's relation, and then check whether
203/// the join itself uses `USING`.
204fn collect_from_table_with_joins(
205    tables: &[TableWithJoins],
206    source: &str,
207    skip: &[bool],
208    using_count: &mut usize,
209    diags: &mut Vec<Diagnostic>,
210) {
211    for twj in tables {
212        // Recurse into the primary relation (e.g. a derived subquery).
213        collect_from_table_factor(&twj.relation, source, skip, using_count, diags);
214
215        for join in &twj.joins {
216            // Recurse into the join's relation (may itself be a subquery).
217            collect_from_table_factor(&join.relation, source, skip, using_count, diags);
218
219            if join_has_using(join) {
220                if let Some(offset) = find_nth_using(source, skip, *using_count) {
221                    let (line, col) = line_col(source, offset);
222                    diags.push(Diagnostic {
223                        rule: "Convention/NoUsingClause",
224                        message:
225                            "JOIN USING clause found; prefer explicit ON conditions for clarity"
226                                .to_string(),
227                        line,
228                        col,
229                    });
230                }
231                *using_count += 1;
232            }
233        }
234    }
235}
236
237fn collect_from_select(
238    select: &Select,
239    source: &str,
240    skip: &[bool],
241    using_count: &mut usize,
242    diags: &mut Vec<Diagnostic>,
243) {
244    collect_from_table_with_joins(&select.from, source, skip, using_count, diags);
245}
246
247fn collect_from_set_expr(
248    expr: &SetExpr,
249    source: &str,
250    skip: &[bool],
251    using_count: &mut usize,
252    diags: &mut Vec<Diagnostic>,
253) {
254    match expr {
255        SetExpr::Select(select) => collect_from_select(select, source, skip, using_count, diags),
256        SetExpr::Query(inner) => collect_from_query(inner, source, skip, using_count, diags),
257        SetExpr::SetOperation { left, right, .. } => {
258            collect_from_set_expr(left, source, skip, using_count, diags);
259            collect_from_set_expr(right, source, skip, using_count, diags);
260        }
261        _ => {}
262    }
263}
264
265fn collect_from_query(
266    query: &Query,
267    source: &str,
268    skip: &[bool],
269    using_count: &mut usize,
270    diags: &mut Vec<Diagnostic>,
271) {
272    // CTEs
273    if let Some(with) = &query.with {
274        for cte in &with.cte_tables {
275            collect_from_query(&cte.query, source, skip, using_count, diags);
276        }
277    }
278    collect_from_set_expr(&query.body, source, skip, using_count, diags);
279}
280
281impl Rule for NoUsingClause {
282    fn name(&self) -> &'static str {
283        "Convention/NoUsingClause"
284    }
285
286    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
287        if !ctx.parse_errors.is_empty() {
288            return Vec::new();
289        }
290
291        let source = &ctx.source;
292        let bytes = source.as_bytes();
293        let skip = build_skip(bytes);
294        let mut diags = Vec::new();
295        let mut using_count = 0usize;
296
297        for stmt in &ctx.statements {
298            if let Statement::Query(query) = stmt {
299                collect_from_query(query, source, &skip, &mut using_count, &mut diags);
300            }
301        }
302
303        diags
304    }
305}