Skip to main content

sqrust_rules/structure/
natural_join.rs

1use sqrust_core::{Diagnostic, FileContext, Rule};
2use sqlparser::ast::{Join, JoinConstraint, JoinOperator, Query, SetExpr, Statement, TableFactor};
3
/// Lint rule that reports every `NATURAL JOIN` found in a query.
///
/// The rule is stateless; all the work happens in its `Rule::check`
/// implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct NaturalJoin;
5
6impl Rule for NaturalJoin {
7    fn name(&self) -> &'static str {
8        "NaturalJoin"
9    }
10
11    fn check(&self, ctx: &FileContext) -> Vec<Diagnostic> {
12        if !ctx.parse_errors.is_empty() {
13            return Vec::new();
14        }
15
16        let mut diags = Vec::new();
17
18        for stmt in &ctx.statements {
19            if let Statement::Query(query) = stmt {
20                check_query(query, &ctx.source, self.name(), &mut diags);
21            }
22        }
23
24        diags
25    }
26}
27
28// ── AST walking ───────────────────────────────────────────────────────────────
29
30fn check_query(query: &Query, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
31    // Recurse into CTEs.
32    if let Some(with) = &query.with {
33        for cte in &with.cte_tables {
34            check_query(&cte.query, source, rule, diags);
35        }
36    }
37    check_set_expr(&query.body, source, rule, diags);
38}
39
40fn check_set_expr(expr: &SetExpr, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
41    match expr {
42        SetExpr::Select(sel) => {
43            for twj in &sel.from {
44                for join in &twj.joins {
45                    check_join(join, source, rule, diags);
46                }
47                recurse_table_factor(&twj.relation, source, rule, diags);
48                for join in &twj.joins {
49                    recurse_table_factor(&join.relation, source, rule, diags);
50                }
51            }
52        }
53        SetExpr::SetOperation { left, right, .. } => {
54            check_set_expr(left, source, rule, diags);
55            check_set_expr(right, source, rule, diags);
56        }
57        SetExpr::Query(inner) => {
58            check_query(inner, source, rule, diags);
59        }
60        _ => {}
61    }
62}
63
64/// Returns true if the `JoinOperator` uses `JoinConstraint::Natural`.
65///
66/// In sqlparser 0.53, NATURAL JOIN is not a standalone `JoinOperator` variant.
67/// Instead it is expressed as any directional join variant with a `Natural`
68/// constraint, e.g. `Inner(JoinConstraint::Natural)` for a plain
69/// `NATURAL JOIN`, or `LeftOuter(JoinConstraint::Natural)` for
70/// `NATURAL LEFT JOIN`, etc.
71fn is_natural(op: &JoinOperator) -> bool {
72    match op {
73        JoinOperator::Inner(c) => matches!(c, JoinConstraint::Natural),
74        JoinOperator::LeftOuter(c) => matches!(c, JoinConstraint::Natural),
75        JoinOperator::RightOuter(c) => matches!(c, JoinConstraint::Natural),
76        JoinOperator::FullOuter(c) => matches!(c, JoinConstraint::Natural),
77        _ => false,
78    }
79}
80
81fn check_join(join: &Join, source: &str, rule: &'static str, diags: &mut Vec<Diagnostic>) {
82    if is_natural(&join.join_operator) {
83        let (line, col) = find_keyword_pos(source, "NATURAL");
84        diags.push(Diagnostic {
85            rule,
86            message: "NATURAL JOIN depends on column naming conventions; use explicit JOIN ON instead".to_string(),
87            line,
88            col,
89        });
90    }
91}
92
93fn recurse_table_factor(
94    tf: &TableFactor,
95    source: &str,
96    rule: &'static str,
97    diags: &mut Vec<Diagnostic>,
98) {
99    if let TableFactor::Derived { subquery, .. } = tf {
100        check_query(subquery, source, rule, diags);
101    }
102}
103
104// ── keyword position helper ───────────────────────────────────────────────────
105
106/// Find the first occurrence of a keyword (case-insensitive, word-boundary)
107/// in `source`. Returns a 1-indexed (line, col) pair. Falls back to (1, 1) if
108/// not found.
109fn find_keyword_pos(source: &str, keyword: &str) -> (usize, usize) {
110    let upper = source.to_uppercase();
111    let kw_upper = keyword.to_uppercase();
112    let kw_len = kw_upper.len();
113    let bytes = upper.as_bytes();
114    let len = bytes.len();
115
116    let mut pos = 0;
117    while pos + kw_len <= len {
118        if let Some(rel) = upper[pos..].find(kw_upper.as_str()) {
119            let abs = pos + rel;
120
121            // Word boundary check.
122            let before_ok = abs == 0 || {
123                let b = bytes[abs - 1];
124                !b.is_ascii_alphanumeric() && b != b'_'
125            };
126            let after = abs + kw_len;
127            let after_ok = after >= len || {
128                let b = bytes[after];
129                !b.is_ascii_alphanumeric() && b != b'_'
130            };
131
132            if before_ok && after_ok {
133                return line_col(source, abs);
134            }
135
136            pos = abs + 1;
137        } else {
138            break;
139        }
140    }
141
142    (1, 1)
143}
144
/// Translates a byte offset into `source` to a 1-indexed (line, col) pair.
///
/// The line is one more than the number of `'\n'` characters before `offset`;
/// the column is the byte distance from the preceding newline (or from the
/// start of `source` when there is none), counted from 1.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset];
    match prefix.rfind('\n') {
        // `offset - newline_at` equals the 1-indexed column on that line.
        Some(newline_at) => (prefix.matches('\n').count() + 1, offset - newline_at),
        None => (1, offset + 1),
    }
}
151}